Skip to content

Commit

Permalink
Merge pull request #32 from MayDomine/zero2_fixed
Browse files Browse the repository at this point in the history
FIX: release the parameter at some special case
  • Loading branch information
a710128 authored Jun 15, 2022
2 parents f2c046d + 632b97a commit 5b96863
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions bmtrain/block_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,13 @@ def exit(self):
for param in self.block._param_info:
kw_name = param["kw_name"]
param["parameter"].grad = None
dtype = self.block._storage_params[kw_name].dtype
device = self.block._storage_params[kw_name].device
if "begin" not in param:
param["parameter"].data = torch.tensor([], dtype=dtype, device=device)
continue
begin = param["begin"]
end = param["end"]
dtype = self.block._storage_params[kw_name].dtype
device = self.block._storage_params[kw_name].device
param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end)
if param["parameter"].requires_grad:
param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end)
Expand Down

0 comments on commit 5b96863

Please sign in to comment.