From 70799b22d62f735871259b1bfe7c46e946f531ac Mon Sep 17 00:00:00 2001 From: rlronan <37887237+rlronan@users.noreply.github.com> Date: Tue, 18 Feb 2020 01:09:53 -0500 Subject: [PATCH 1/6] Update Arena.py See: https://github.com/suragnair/alpha-zero-general/issues/156 --- Arena.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Arena.py b/Arena.py index d741e7b6e..15c19445d 100644 --- a/Arena.py +++ b/Arena.py @@ -55,7 +55,7 @@ def playGame(self, verbose=False): assert(self.display) print("Game over: Turn ", str(it), "Result ", str(self.game.getGameEnded(board, 1))) self.display(board) - return self.game.getGameEnded(board, 1) + return curPlayer*self.game.getGameEnded(board, curPlayer) def playGames(self, num, verbose=False): """ From 76e716e201a6e42132b1b8eb875a902d950077b6 Mon Sep 17 00:00:00 2001 From: rlronan <37887237+rlronan@users.noreply.github.com> Date: Sun, 8 Mar 2020 12:22:55 -0400 Subject: [PATCH 2/6] Save valid moves with training examples --- Coach.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/Coach.py b/Coach.py index 8bf355c27..59af73b16 100644 --- a/Coach.py +++ b/Coach.py @@ -50,8 +50,15 @@ def executeEpisode(self): pi = self.mcts.getActionProb(canonicalBoard, temp=temp) sym = self.game.getSymmetries(canonicalBoard, pi) - for b,p in sym: - trainExamples.append([b, self.curPlayer, p, None]) + + # ideally these should be combined so that getSymmetries takes valids as well + bs, ps = zip(*self.game.getSymmetries(canonicalBoard, pi)) + _, valids_sym = zip(*self.game.getSymmetries(canonicalBoard, valids)) + sym = zip(bs,ps,valids_sym) + + for b,p,valid in sym: + # previous was: [b, self.curPlayer, p, None], but only 3 values were returned + trainExamples.append([b, self.curPlayer, p, valid]) action = np.random.choice(len(pi), p=pi) board, self.curPlayer = self.game.getNextState(board, self.curPlayer, action) @@ -59,7 +66,7 @@ def executeEpisode(self): r = self.game.getGameEnded(board, self.curPlayer) if r!=0: - return [(x[0],x[2],r*((-1)**(x[1]!=self.curPlayer))) for x in trainExamples] + return [(x[0],x[2],r*((-1)**(x[1]!=self.curPlayer)),x[3]) for x in trainExamples] def learn(self): """ From 2008ae967b8b27024ac6c25ec0deb93e8b413fd5 Mon Sep 17 00:00:00 2001 From: rlronan <37887237+rlronan@users.noreply.github.com> Date: Sun, 8 Mar 2020 12:25:12 -0400 Subject: [PATCH 3/6] Update Coach.py --- Coach.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Coach.py b/Coach.py index 59af73b16..6d2aeabda 100644 --- a/Coach.py +++ b/Coach.py @@ -49,7 +49,7 @@ def executeEpisode(self): temp = int(episodeStep < self.args.tempThreshold) pi = self.mcts.getActionProb(canonicalBoard, temp=temp) - sym = self.game.getSymmetries(canonicalBoard, pi) + valids = self.game.getValidMoves(canonicalBoard, 1) # ideally these should be combined so that getSymmetries takes valids as well bs, ps = zip(*self.game.getSymmetries(canonicalBoard, pi)) From 80a43f2c744d8e1d342ed2b463f7b306a9bc4526 Mon Sep 17 00:00:00 2001 From: rlronan <37887237+rlronan@users.noreply.github.com> Date: Sun, 8 Mar 2020 12:29:03 -0400 Subject: [PATCH 4/6] Update NNet to receive and pass valids --- othello/tensorflow/NNet.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/othello/tensorflow/NNet.py b/othello/tensorflow/NNet.py index 5325ac3a9..d06e0802b 100644 --- a/othello/tensorflow/NNet.py +++ b/othello/tensorflow/NNet.py @@ -52,10 +52,10 @@ def train(self, examples): # self.sess.run(tf.local_variables_initializer()) while batch_idx < int(len(examples)/args.batch_size): sample_ids = np.random.randint(len(examples), size=args.batch_size) - boards, pis, vs = list(zip(*[examples[i] for i in sample_ids])) + boards, pis, vs, valids = list(zip(*[examples[i] for i in sample_ids])) # predict and compute gradient and do SGD step - input_dict = {self.nnet.input_boards: boards, self.nnet.target_pis: pis, self.nnet.target_vs: vs, self.nnet.dropout: args.dropout, self.nnet.isTraining: True} + input_dict = {self.nnet.input_boards: boards, self.nnet.target_pis: pis, self.nnet.target_vs: vs, self.nnet.valids: valids, self.nnet.dropout: args.dropout, self.nnet.isTraining: True} # measure data loading time data_time.update(time.time() - end) @@ -86,18 +86,21 @@ def train(self, examples): bar.finish() - def predict(self, board): + def predict(self, boardAndValids): """ board: np array with board """ # timing start = time.time() + board, valids = boardAndValids + # preparing input board = board[np.newaxis, :, :] + valids = valids[np.newaxis, :] # run - prob, v = self.sess.run([self.nnet.prob, self.nnet.v], feed_dict={self.nnet.input_boards: board, self.nnet.dropout: 0, self.nnet.isTraining: False}) + prob, v = self.sess.run([self.nnet.prob, self.nnet.v], feed_dict={self.nnet.input_boards: board, self.nnet.valids: valids, self.nnet.dropout: 0, self.nnet.isTraining: False}) #print('PREDICTION TIME TAKEN : {0:03f}'.format(time.time()-start)) return prob[0], v[0] @@ -120,4 +123,4 @@ def load_checkpoint(self, folder='checkpoint', filename='checkpoint.pth.tar'): raise("No model in path {}".format(filepath)) with self.nnet.graph.as_default(): self.saver = tf.train.Saver() - self.saver.restore(self.sess, filepath) \ No newline at end of file + self.saver.restore(self.sess, filepath) From e1d1158c8114c4e53efc0cbe3d915567ce26c909 Mon Sep 17 00:00:00 2001 From: rlronan <37887237+rlronan@users.noreply.github.com> Date: Sun, 8 Mar 2020 12:36:05 -0400 Subject: [PATCH 5/6] Net now recieves and filters valid moves --- othello/tensorflow/OthelloNNet.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/othello/tensorflow/OthelloNNet.py b/othello/tensorflow/OthelloNNet.py index 92f28c5c4..72cc064fb 100644 --- a/othello/tensorflow/OthelloNNet.py +++ b/othello/tensorflow/OthelloNNet.py @@ -23,6 +23,7 @@ def __init__(self, game, args): with self.graph.as_default(): self.input_boards = tf.placeholder(tf.float32, shape=[None, self.board_x, self.board_y]) # s: batch_size x board_x x board_y self.dropout = tf.placeholder(tf.float32) + self.valids = tf.placeholder(tf.float32, shape=[None, self.action_size]) self.isTraining = tf.placeholder(tf.bool, name="is_training") x_image = tf.reshape(self.input_boards, [-1, self.board_x, self.board_y, 1]) # batch_size x board_x x board_y x 1 @@ -34,6 +35,11 @@ def __init__(self, game, args): s_fc1 = Dropout(Relu(BatchNormalization(Dense(h_conv4_flat, 1024, use_bias=False), axis=1, training=self.isTraining)), rate=self.dropout) # batch_size x 1024 s_fc2 = Dropout(Relu(BatchNormalization(Dense(s_fc1, 512, use_bias=False), axis=1, training=self.isTraining)), rate=self.dropout) # batch_size x 512 self.pi = Dense(s_fc2, self.action_size) # batch_size x self.action_size + + # this sets the value of invalid moves to ~ -1000, so that the network is not encouraged to set both + # illegal moves, and low-quality moves to a value of 0, as sigmoid(-1000) approx= 0. + # see: https://github.com/suragnair/alpha-zero-general/issues/77 + self.pi -= (1-self.valids)*1000 self.prob = tf.nn.softmax(self.pi) self.v = Tanh(Dense(s_fc2, 1)) # batch_size x 1 @@ -64,6 +70,7 @@ def __init__(self, game, args): with self.graph.as_default(): self.input_boards = tf.placeholder(tf.float32, shape=[None, self.board_x, self.board_y]) # s: batch_size x board_x x board_y self.dropout = tf.placeholder(tf.float32) + self.valids = tf.placeholder(tf.float32, shape=[None, self.action_size]) self.isTraining = tf.placeholder(tf.bool, name="is_training") x_image = tf.reshape(self.input_boards, [-1, self.board_x, self.board_y, 1]) # batch_size x board_x x board_y x 1 @@ -96,6 +103,11 @@ def __init__(self, game, args): policy = tf.nn.relu(policy) policy = tf.layers.flatten(policy, name='p_flatten') self.pi = tf.layers.dense(policy, self.action_size) + + # this sets the value of invalid moves to ~ -1000, so that the network is not encouraged to set both + # illegal moves, and low-quality moves to a value of 0. + # see: https://github.com/suragnair/alpha-zero-general/issues/77 + self.pi -= (1-self.valids)*1000 self.prob = tf.nn.softmax(self.pi) value = tf.layers.conv2d(residual_tower, 1,kernel_size=(1, 1), strides=(1, 1),name='v',padding='same',use_bias=False) From 5814067f20c32222e51ee41caf871d89a2d9a69d Mon Sep 17 00:00:00 2001 From: rlronan <37887237+rlronan@users.noreply.github.com> Date: Sun, 8 Mar 2020 12:40:14 -0400 Subject: [PATCH 6/6] Update MCTS.py --- MCTS.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/MCTS.py b/MCTS.py index 414133e58..f76eaee4c 100644 --- a/MCTS.py +++ b/MCTS.py @@ -76,8 +76,10 @@ def search(self, canonicalBoard): if s not in self.Ps: # leaf node - self.Ps[s], v = self.nnet.predict(canonicalBoard) + + # see: https://github.com/suragnair/alpha-zero-general/issues/77 valids = self.game.getValidMoves(canonicalBoard, 1) + self.Ps[s], v = self.nnet.predict((canonicalBoard, valids)) self.Ps[s] = self.Ps[s]*valids # masking invalid moves sum_Ps_s = np.sum(self.Ps[s]) if sum_Ps_s > 0: