diff --git a/torchrl/data/replay_buffers/scheduler.py b/torchrl/data/replay_buffers/scheduler.py index 3b98199db22..6829424c620 100644 --- a/torchrl/data/replay_buffers/scheduler.py +++ b/torchrl/data/replay_buffers/scheduler.py @@ -169,6 +169,8 @@ def __init__( self._delta = (self.final_val - self.initial_val) / self.num_steps def _step(self): + # Nit: we should use torch.where instead of if/else here to make the scheduler compatible with compile + # without graph breaks if self._step_cnt < self.num_steps: return self.initial_val + (self._delta * self._step_cnt) else: @@ -241,6 +243,8 @@ def _step(self): """Applies the scheduling logic to alter the parameter value every `n_steps`.""" # Check if the current step count is a multiple of n_steps current_val = getattr(self.sampler, self.param_name) + # Nit: we should use torch.where instead of if/else here to make the scheduler compatible with compile + # without graph breaks if self._step_cnt % self.n_steps == 0: return self.operator(current_val, self.gamma) else: