Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 17, 2024
1 parent a8c6615 commit 4f41d4d
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 3 deletions.
3 changes: 1 addition & 2 deletions sota-implementations/ppo/ppo_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,11 @@ def main(cfg: "DictConfig"): # noqa: F821

# Create collector
collector = SyncDataCollector(
create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, "cpu"),
create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
policy=actor,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
compile_policy={"mode": compile_mode, "warmup": 1} if compile_mode else False,
cudagraph_policy=cfg.compile.cudagraphs,
Expand Down
1 change: 0 additions & 1 deletion sota-implementations/ppo/ppo_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def main(cfg: "DictConfig"): # noqa: F821
frames_per_batch=cfg.collector.frames_per_batch,
total_frames=cfg.collector.total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
compile_policy={"mode": compile_mode, "warmup": 1} if compile_mode else False,
cudagraph_policy=cfg.compile.cudagraphs,
Expand Down

0 comments on commit 4f41d4d

Please sign in to comment.