Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Oct 8, 2024
1 parent 2315434 commit 0d1ac6a
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 8 deletions.
4 changes: 1 addition & 3 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ sphinx_design
torchvision
dm_control
mujoco
atari-py
ale-py
gym[classic_control,accept-rom-license]
gym[classic_control,accept-rom-license,ale-py,atari]
pygame
tqdm
ipython
Expand Down
1 change: 0 additions & 1 deletion sota-implementations/redq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,6 @@ def make_collector_offpolicy(
"init_random_frames": cfg.collector.init_random_frames,
"split_trajs": True,
# trajectories must be separated if multi-step is used
"exploration_type": cfg.collector.exploration_type,
}

collector = collector_helper(**collector_helper_kwargs)
Expand Down
2 changes: 1 addition & 1 deletion tutorials/sphinx-tutorials/coding_ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ def make_recorder(actor_model_explore, transform_state_dict, record_interval):
record_frames=1000,
policy_exploration=actor_model_explore,
environment=environment,
exploration_type=ExplorationType.MEAN,
exploration_type=ExplorationType.DETERMINISTIC,
record_interval=record_interval,
)
return recorder_obj
Expand Down
2 changes: 1 addition & 1 deletion tutorials/sphinx-tutorials/coding_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@
# number of steps (1000, which is our ``env`` horizon).
# The ``rollout`` method of the ``env`` can take a policy as argument:
# it will then execute this policy at each step.
with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
# execute a rollout with the trained policy
eval_rollout = env.rollout(1000, policy_module)
logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
Expand Down
4 changes: 2 additions & 2 deletions tutorials/sphinx-tutorials/getting-started-1.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@

from torchrl.envs.utils import ExplorationType, set_exploration_type

with set_exploration_type(ExplorationType.MEAN):
with set_exploration_type(ExplorationType.DETERMINISTIC):
# takes the mean as action
rollout = env.rollout(max_steps=10, policy=policy)
with set_exploration_type(ExplorationType.RANDOM):
Expand Down Expand Up @@ -221,7 +221,7 @@

exploration_policy = TensorDictSequential(policy, exploration_module)

with set_exploration_type(ExplorationType.MEAN):
with set_exploration_type(ExplorationType.DETERMINISTIC):
# Turns off exploration
rollout = env.rollout(max_steps=10, policy=exploration_policy)
with set_exploration_type(ExplorationType.RANDOM):
Expand Down

0 comments on commit 0d1ac6a

Please sign in to comment.