Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 18, 2024
1 parent 595808e commit ace76f3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
9 changes: 5 additions & 4 deletions sota-implementations/ppo/ppo_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,11 @@ def update(batch, num_network_updates):
data_buffer.extend(data_reshape)

for k, batch in enumerate(data_buffer):
torch.compiler.cudagraph_mark_step_begin()
loss, num_network_updates = update(
batch, num_network_updates=num_network_updates
)
with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
loss, num_network_updates = update(
batch, num_network_updates=num_network_updates
)
loss = loss.clone()
num_network_updates = num_network_updates.clone()
losses[j, k] = loss.select(
Expand Down
11 changes: 6 additions & 5 deletions sota-implementations/ppo/ppo_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,11 +239,12 @@ def update(batch, num_network_updates):
data_buffer.extend(data_reshape)

for k, batch in enumerate(data_buffer):
torch.compiler.cudagraph_mark_step_begin()
loss, num_network_updates = update(
batch, num_network_updates=num_network_updates
)
loss = loss.clone()
with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
loss, num_network_updates = update(
batch, num_network_updates=num_network_updates
)
loss = loss.clone()
num_network_updates = num_network_updates.clone()
losses[j, k] = loss.select(
"loss_critic", "loss_entropy", "loss_objective"
Expand Down

0 comments on commit ace76f3

Please sign in to comment.