
Commit d0132cb

fix torch tensor memory not being released due to gradient links

yuchuang committed May 21, 2024
1 parent b3f570c commit d0132cb
Showing 1 changed file with 10 additions and 9 deletions.
torchopt/optim/meta/base.py (10 additions, 9 deletions)

@@ -72,21 +72,22 @@ def step(self, loss: torch.Tensor) -> None:  # pylint: disable=too-many-locals
         ):
             flat_params: TupleOfTensors
             flat_params, container_treespec = pytree.tree_flatten_as_tuple(param_container)  # type: ignore[arg-type]
-            if isinstance(state, UninitializedState):
-                state = self.impl.init(flat_params)
             grads = torch.autograd.grad(
                 loss,
                 flat_params,
                 create_graph=True,
                 allow_unused=True,
             )
-            updates, new_state = self.impl.update(
-                grads,
-                state,
-                params=flat_params,
-                inplace=False,
-            )
-            self.state_groups[i] = new_state
+            with torch.no_grad():
+                if isinstance(state, UninitializedState):
+                    state = self.impl.init(flat_params)
+                updates, new_state = self.impl.update(
+                    grads,
+                    state,
+                    params=flat_params,
+                    inplace=False,
+                )
+                self.state_groups[i] = new_state
             flat_new_params = apply_updates(flat_params, updates, inplace=False)
             new_params: ModuleTensorContainers = pytree.tree_unflatten(  # type: ignore[assignment]
                 container_treespec,
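
For context, here is a minimal sketch (illustrative only, not part of the commit; all names and shapes below are hypothetical) of the PyTorch behavior the fix targets: gradients computed with create_graph=True are themselves differentiable, so any tensor derived from them outside a no_grad block carries a grad_fn that pins the whole per-step autograd graph for as long as that tensor is stored.

    # Illustrative sketch, assuming only core PyTorch. Hypothetical names.
    import torch

    params = (torch.randn(512, 512, requires_grad=True),)
    state_groups = []  # stands in for MetaOptimizer.state_groups

    for _ in range(3):
        loss = (params[0] ** 2).sum()
        # create_graph=True makes `grad` itself differentiable, so anything
        # derived from it outside no_grad links back to this step's graph.
        (grad,) = torch.autograd.grad(loss, params, create_graph=True)

        leaky = grad * 0.9  # graph-linked: storing this would pin the graph
        with torch.no_grad():
            clean = grad * 0.9  # plain tensor: no grad_fn to the graph

        assert leaky.grad_fn is not None and clean.grad_fn is None
        state_groups.append(clean)  # safe to keep across iterations

Because `clean` has no grad_fn, holding it in `state_groups` across iterations does not retain the step's graph, so those intermediate tensors can be freed once the step completes, which is the release behavior the commit message describes.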
