diff --git a/baselines/ppo2/ppo2.py b/baselines/ppo2/ppo2.py index d8ef42a5d5..9a1300385c 100644 --- a/baselines/ppo2/ppo2.py +++ b/baselines/ppo2/ppo2.py @@ -218,7 +218,7 @@ def f(_): return val return f -def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0, lr=3e-4, +def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2048, ent_coef=0.0, lr=3e-4, vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95, log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2, save_interval=0, load_path=None, **network_kwargs): @@ -307,8 +307,14 @@ def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0 model.load(load_path) # Instantiate the runner object runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam) + if eval_env is not None: + eval_runner = Runner(env = eval_env, model = model, nsteps = nsteps, gamma = gamma, lam= lam) + + epinfobuf = deque(maxlen=100) + if eval_env is not None: + eval_epinfobuf = deque(maxlen=100) # Start total timer tfirststart = time.time() @@ -325,7 +331,12 @@ def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0 cliprangenow = cliprange(frac) # Get minibatch obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run() #pylint: disable=E0632 + if eval_env is not None: + eval_obs, eval_returns, eval_masks, eval_actions, eval_values, eval_neglogpacs, eval_states, eval_epinfos = eval_runner.run() #pylint: disable=E0632 + epinfobuf.extend(epinfos) + if eval_env is not None: + eval_epinfobuf.extend(eval_epinfos) # Here what we're going to do is for each minibatch calculate the loss and append it. mblossvals = [] @@ -375,6 +386,9 @@ def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0 logger.logkv("explained_variance", float(ev)) logger.logkv('eprewmean', safemean([epinfo['r'] for epinfo in epinfobuf])) logger.logkv('eplenmean', safemean([epinfo['l'] for epinfo in epinfobuf])) + if eval_env is not None: + logger.logkv('eval_eprewmean', safemean([epinfo['r'] for epinfo in eval_epinfobuf]) ) + logger.logkv('eval_eplenmean', safemean([epinfo['l'] for epinfo in eval_epinfobuf]) ) logger.logkv('time_elapsed', tnow - tfirststart) for (lossval, lossname) in zip(lossvals, model.loss_names): logger.logkv(lossname, lossval)