diff --git a/torchopt/optim/meta/base.py b/torchopt/optim/meta/base.py index 73ecdde7..fcc4542f 100644 --- a/torchopt/optim/meta/base.py +++ b/torchopt/optim/meta/base.py @@ -72,21 +72,22 @@ def step(self, loss: torch.Tensor) -> None: # pylint: disable=too-many-locals ): flat_params: TupleOfTensors flat_params, container_treespec = pytree.tree_flatten_as_tuple(param_container) # type: ignore[arg-type] - if isinstance(state, UninitializedState): - state = self.impl.init(flat_params) grads = torch.autograd.grad( loss, flat_params, create_graph=True, allow_unused=True, ) - updates, new_state = self.impl.update( - grads, - state, - params=flat_params, - inplace=False, - ) - self.state_groups[i] = new_state + with torch.no_grad(): + if isinstance(state, UninitializedState): + state = self.impl.init(flat_params) + updates, new_state = self.impl.update( + grads, + state, + params=flat_params, + inplace=False, + ) + self.state_groups[i] = new_state flat_new_params = apply_updates(flat_params, updates, inplace=False) new_params: ModuleTensorContainers = pytree.tree_unflatten( # type: ignore[assignment] container_treespec,