From 632b97a5b37691d2bad0381b9691900928a917d2 Mon Sep 17 00:00:00 2001 From: Bojack <57244158+MayDomine@users.noreply.github.com> Date: Wed, 15 Jun 2022 12:52:33 +0800 Subject: [PATCH] FIX: release the parameter in some special cases --- bmtrain/block_layer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index cff15ccc..561b0039 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -251,12 +251,13 @@ def exit(self): for param in self.block._param_info: kw_name = param["kw_name"] param["parameter"].grad = None + dtype = self.block._storage_params[kw_name].dtype + device = self.block._storage_params[kw_name].device if "begin" not in param: + param["parameter"].data = torch.tensor([], dtype=dtype, device=device) continue begin = param["begin"] end = param["end"] - dtype = self.block._storage_params[kw_name].dtype - device = self.block._storage_params[kw_name].device param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end) if param["parameter"].requires_grad: param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end)