Skip to content

Commit

Permalink
ppo2 now has eval stats (#120)
Browse files Browse the repository at this point in the history
* ppo2 now has eval stats

* fixed spaces

* fixed kwargs ordering

* whitespace fix
  • Loading branch information
xingyousong authored and pzhokhov committed Oct 3, 2018
1 parent 858afa8 commit e820b86
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion baselines/ppo2/ppo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def f(_):
return val
return f

def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0, lr=3e-4,
def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2048, ent_coef=0.0, lr=3e-4,
vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95,
log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,
save_interval=0, load_path=None, **network_kwargs):
Expand Down Expand Up @@ -307,8 +307,14 @@ def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0
model.load(load_path)
# Instantiate the runner object
runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)
if eval_env is not None:
eval_runner = Runner(env = eval_env, model = model, nsteps = nsteps, gamma = gamma, lam= lam)



epinfobuf = deque(maxlen=100)
if eval_env is not None:
eval_epinfobuf = deque(maxlen=100)

# Start total timer
tfirststart = time.time()
Expand All @@ -325,7 +331,12 @@ def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0
cliprangenow = cliprange(frac)
# Get minibatch
obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run() #pylint: disable=E0632
if eval_env is not None:
eval_obs, eval_returns, eval_masks, eval_actions, eval_values, eval_neglogpacs, eval_states, eval_epinfos = eval_runner.run() #pylint: disable=E0632

epinfobuf.extend(epinfos)
if eval_env is not None:
eval_epinfobuf.extend(eval_epinfos)

# Here what we're going to do is for each minibatch calculate the loss and append it.
mblossvals = []
Expand Down Expand Up @@ -375,6 +386,9 @@ def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0
logger.logkv("explained_variance", float(ev))
logger.logkv('eprewmean', safemean([epinfo['r'] for epinfo in epinfobuf]))
logger.logkv('eplenmean', safemean([epinfo['l'] for epinfo in epinfobuf]))
if eval_env is not None:
logger.logkv('eval_eprewmean', safemean([epinfo['r'] for epinfo in eval_epinfobuf]) )
logger.logkv('eval_eplenmean', safemean([epinfo['l'] for epinfo in eval_epinfobuf]) )
logger.logkv('time_elapsed', tnow - tfirststart)
for (lossval, lossname) in zip(lossvals, model.loss_names):
logger.logkv(lossname, lossval)
Expand Down

0 comments on commit e820b86

Please sign in to comment.