diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py
index cff15ccc..561b0039 100644
--- a/bmtrain/block_layer.py
+++ b/bmtrain/block_layer.py
@@ -251,12 +251,13 @@ def exit(self):
         for param in self.block._param_info:
             kw_name = param["kw_name"]
             param["parameter"].grad = None
+            dtype = self.block._storage_params[kw_name].dtype
+            device = self.block._storage_params[kw_name].device
             if "begin" not in param:
+                param["parameter"].data = torch.tensor([], dtype=dtype, device=device)
                 continue
             begin = param["begin"]
             end = param["end"]
-            dtype = self.block._storage_params[kw_name].dtype
-            device = self.block._storage_params[kw_name].device
             param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end)
             if param["parameter"].requires_grad:
                 param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end)