Skip to content

Commit

Permalink
Refactor DDPG (openai#111)
Browse files Browse the repository at this point in the history
* run ddpg on Mujoco benchmark RUN BENCHMARKS

* autopep8

* fixed all syntax in refactored ddpg

* a little bit more refactoring

* autopep8

* identity test with ddpg WIP

* enable test_identity with ddpg

* refactored ddpg RUN BENCHMARKS

* autopep8

* include ddpg into style check

* fixing tests RUN BENCHMARKS

* set default seed to None RUN BENCHMARKS

* run tests and benchmarks in separate buildkite steps RUN BENCHMARKS

* cleanup pdb usage

* flake8 and cleanups

* re-enabled all benchmarks in run-benchmarks-new.py

* flake8 complaints

* deepq model builder compatible with network functions returning single tensor

* remove ddpg test with test_discrete_identity

* make ppo_metal use make_vec_env instead of make_atari_env

* make ppo_metal use make_vec_env instead of make_atari_env

* fixed syntax in ppo_metal.run_atari
  • Loading branch information
pzhokhov committed Oct 3, 2018
1 parent 8a562e8 commit 367a82a
Show file tree
Hide file tree
Showing 10 changed files with 664 additions and 745 deletions.
20 changes: 13 additions & 7 deletions baselines/common/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def nature_cnn(unscaled_images, **conv_kwargs):


@register("mlp")
def mlp(num_layers=2, num_hidden=64, activation=tf.tanh):
def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):
"""
Stack of fully-connected layers to be used in a policy / q-function approximator
Expand All @@ -49,16 +49,20 @@ def mlp(num_layers=2, num_hidden=64, activation=tf.tanh):
def network_fn(X):
h = tf.layers.flatten(X)
for i in range(num_layers):
h = activation(fc(h, 'mlp_fc{}'.format(i), nh=num_hidden, init_scale=np.sqrt(2)))
return h, None
h = fc(h, 'mlp_fc{}'.format(i), nh=num_hidden, init_scale=np.sqrt(2))
if layer_norm:
h = tf.contrib.layers.layer_norm(h, center=True, scale=True)
h = activation(h)

return h

return network_fn


@register("cnn")
def cnn(**conv_kwargs):
def network_fn(X):
return nature_cnn(X, **conv_kwargs), None
return nature_cnn(X, **conv_kwargs)
return network_fn


Expand All @@ -72,7 +76,7 @@ def network_fn(X):
h = activ(conv(h, 'c2', nf=16, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs))
h = conv_to_fc(h)
h = activ(fc(h, 'fc1', nh=128, init_scale=np.sqrt(2)))
return h, None
return h
return network_fn


Expand Down Expand Up @@ -190,7 +194,7 @@ def network_fn(X):
activation_fn=tf.nn.relu,
**conv_kwargs)

return out, None
return out
return network_fn

def _normalize_clip_observation(x, clip_range=[-5.0, 5.0]):
Expand All @@ -212,7 +216,9 @@ def your_network_define(**net_kwargs):
return network_fn
"""
if name in mapping:
if callable(name):
return name
elif name in mapping:
return mapping[name]
else:
raise ValueError('Unknown network type: {}'.format(name))
19 changes: 11 additions & 8 deletions baselines/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,16 @@ def policy_fn(nbatch=None, nsteps=None, sess=None, observ_placeholder=None):
encoded_x = encode_observation(ob_space, encoded_x)

with tf.variable_scope('pi', reuse=tf.AUTO_REUSE):
policy_latent, recurrent_tensors = policy_network(encoded_x)
policy_latent = policy_network(encoded_x)
if isinstance(policy_latent, tuple):
policy_latent, recurrent_tensors = policy_latent

if recurrent_tensors is not None:
# recurrent architecture, need a few more steps
nenv = nbatch // nsteps
assert nenv > 0, 'Bad input for recurrent policy: batch size {} smaller than nsteps {}'.format(nbatch, nsteps)
policy_latent, recurrent_tensors = policy_network(encoded_x, nenv)
extra_tensors.update(recurrent_tensors)
if recurrent_tensors is not None:
# recurrent architecture, need a few more steps
nenv = nbatch // nsteps
assert nenv > 0, 'Bad input for recurrent policy: batch size {} smaller than nsteps {}'.format(nbatch, nsteps)
policy_latent, recurrent_tensors = policy_network(encoded_x, nenv)
extra_tensors.update(recurrent_tensors)


_v_net = value_network
Expand All @@ -160,7 +162,8 @@ def policy_fn(nbatch=None, nsteps=None, sess=None, observ_placeholder=None):
assert callable(_v_net)

with tf.variable_scope('vf', reuse=tf.AUTO_REUSE):
vf_latent, _ = _v_net(encoded_x)
# TODO recurrent architectures are not supported with value_network=copy yet
vf_latent = _v_net(encoded_x)

policy = PolicyWithValue(
env=env,
Expand Down
10 changes: 7 additions & 3 deletions baselines/common/tests/test_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@
'a2c' : {},
'acktr': {},
'deepq': {},
'ddpg': dict(nb_epochs=None, layer_norm=True),
'ppo2': dict(lr=1e-3, nsteps=64, ent_coef=0.0),
'trpo_mpi': dict(timesteps_per_batch=100, cg_iters=10, gamma=0.9, lam=1.0, max_kl=0.01)
}


algos_disc = ['a2c', 'deepq', 'ppo2', 'trpo_mpi']
algos_cont = ['a2c', 'ddpg', 'ppo2', 'trpo_mpi']

@pytest.mark.slow
@pytest.mark.parametrize("alg", learn_kwargs.keys())
@pytest.mark.parametrize("alg", algos_disc)
def test_discrete_identity(alg):
'''
Test if the algorithm (with an mlp policy)
Expand All @@ -35,7 +39,7 @@ def test_discrete_identity(alg):
simple_test(env_fn, learn_fn, 0.9)

@pytest.mark.slow
@pytest.mark.parametrize("alg", ['a2c', 'ppo2', 'trpo_mpi'])
@pytest.mark.parametrize("alg", algos_cont)
def test_continuous_identity(alg):
'''
Test if the algorithm (with an mlp policy)
Expand All @@ -51,5 +55,5 @@ def test_continuous_identity(alg):
simple_test(env_fn, learn_fn, -0.1)

if __name__ == '__main__':
test_continuous_identity('a2c')
test_continuous_identity('ddpg')

Loading

0 comments on commit 367a82a

Please sign in to comment.