Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Sep 30, 2024
1 parent 915d1c4 commit 4b2897a
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions torchrl/data/replay_buffers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ def __init__(
self._delta = (self.final_val - self.initial_val) / self.num_steps

def _step(self):
# Nit: we should use torch.where instead than if/else here to make the scheduler compatible with compile
# without graph breaks
if self._step_cnt < self.num_steps:
return self.initial_val + (self._delta * self._step_cnt)
else:
Expand Down Expand Up @@ -241,6 +243,8 @@ def _step(self):
"""Applies the scheduling logic to alter the parameter value every `n_steps`."""
# Check if the current step count is a multiple of n_steps
current_val = getattr(self.sampler, self.param_name)
# Nit: we should use torch.where instead than if/else here to make the scheduler compatible with compile
# without graph breaks
if self._step_cnt % self.n_steps == 0:
return self.operator(current_val, self.gamma)
else:
Expand Down

0 comments on commit 4b2897a

Please sign in to comment.