Commit
FIX: release the parameter at some special case
MayDomine authored Jun 15, 2022
1 parent f2c046d · commit 632b97a
Showing 1 changed file with 3 additions and 2 deletions.
bmtrain/block_layer.py: 5 changes (3 additions & 2 deletions)
@@ -251,12 +251,13 @@ def exit(self):
         for param in self.block._param_info:
             kw_name = param["kw_name"]
             param["parameter"].grad = None
+            dtype = self.block._storage_params[kw_name].dtype
+            device = self.block._storage_params[kw_name].device
             if "begin" not in param:
+                param["parameter"].data = torch.tensor([], dtype=dtype, device=device)
                 continue
             begin = param["begin"]
             end = param["end"]
-            dtype = self.block._storage_params[kw_name].dtype
-            device = self.block._storage_params[kw_name].device
             param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end)
             if param["parameter"].requires_grad:
                 param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end)
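For context, here is a minimal sketch (assumed names, not BMTrain's actual code) of the aliasing pattern this fix touches. Each parameter normally aliases a slice of a flat storage buffer via Tensor.set_(); the special case fixed here is a parameter whose _param_info entry carries no "begin"/"end" slice, which is now released by rebinding its .data to an empty tensor rather than leaving it pointing at stale storage. The names flat and p below are hypothetical stand-ins, and the sketch passes the offset and size to set_() as keywords.

    import torch

    # Hypothetical stand-ins: "flat" plays the role of
    # self.block._storage_params[kw_name]; "p" plays param["parameter"].
    flat = torch.zeros(16)
    p = torch.nn.Parameter(torch.empty(0), requires_grad=False)

    # Normal case: the parameter aliases the slice [begin, begin + size)
    # of the flat buffer; no copy is made, the two views share memory.
    begin, size = 4, 4
    p.data = torch.tensor([], dtype=flat.dtype, device=flat.device).set_(
        flat.storage(), storage_offset=begin, size=(size,)
    )
    p.data.fill_(1.0)
    assert flat[4:8].eq(1.0).all()  # writes through the alias reach the buffer

    # Special case addressed by this commit: a parameter with no
    # "begin"/"end" entry owns no slice, so its .data is rebound to a
    # fresh empty tensor, dropping the reference to whatever storage it
    # previously held and letting that memory be freed.
    p.data = torch.tensor([], dtype=flat.dtype, device=flat.device)
    assert p.data.numel() == 0

Rebinding .data rather than deleting the object keeps the nn.Parameter alive for a later re-gather while releasing the memory it pointed at, which appears to be the intent of the one-line addition inside the if branch.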
