Fix `load_from_state_dict` in lazy `Linear` modules (#8242)
Fixes #8234 and #8229.
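
For context, a minimal sketch of the round-trip this commit repairs. This is an illustrative snippet inferred from the updated test below, not taken from the linked issues; `Linear` here is `torch_geometric.nn.Linear`, not `torch.nn.Linear`:

```python
from torch_geometric.nn import Linear

# A lazily-initialized module: in_channels=-1 keeps `weight` as an
# `UninitializedParameter` until the first forward pass.
lazy = Linear(-1, 32)

# Copying its state dict into an already materialized module (or vice
# versa) crosses the lazy/materialized boundary this commit fixes:
dense = Linear(16, 32)
dense.load_state_dict(lazy.state_dict())
```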
rusty1s authored Oct 21, 2023 · commit 114ddca (1 parent: 9632694)
Showing 3 changed files with 35 additions and 5 deletions.
CHANGELOG.md (3 changes: 2 additions & 1 deletion)
```diff
@@ -20,7 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
-- Fixed Pre-trained `DimeNet++` performance on QM9 ([#8239](https://github.com/pyg-team/pytorch_geometric/pull/8239))
+- Fixed `load_from_state_dict` in lazy `Linear` modules ([#8242](https://github.com/pyg-team/pytorch_geometric/pull/8242))
+- Fixed pre-trained `DimeNet++` performance on `QM9` ([#8239](https://github.com/pyg-team/pytorch_geometric/pull/8239))
 - Fixed `GNNExplainer` usage within `AttentiveFP` ([#8216](https://github.com/pyg-team/pytorch_geometric/pull/8216))
 - Fixed `to_networkx(to_undirected=True)` in case the input graph is not undirected ([#8204](https://github.com/pyg-team/pytorch_geometric/pull/8204))
 - Fixed sparse-sparse matrix multiplication support on Windows in `TwoHop` and `AddRandomWalkPE` transformations ([#8197](https://github.com/pyg-team/pytorch_geometric/pull/8197), [#8225](https://github.com/pyg-team/pytorch_geometric/pull/8225))
```
test/nn/dense/test_linear.py (18 changes: 14 additions & 4 deletions)
```diff
@@ -45,14 +45,16 @@ def test_lazy_linear(weight, bias, device):
 @withCUDA
 @pytest.mark.parametrize('dim1', [-1, 16])
 @pytest.mark.parametrize('dim2', [-1, 16])
-def test_load_lazy_linear(dim1, dim2, device):
-    lin1 = Linear(dim1, 32).to(device)
-    lin2 = Linear(dim1, 32).to(device)
+@pytest.mark.parametrize('bias', [True, False])
+def test_load_lazy_linear(dim1, dim2, bias, device):
+    lin1 = Linear(dim1, 32, bias=bias).to(device)
+    lin2 = Linear(dim2, 32, bias=bias).to(device)
     lin2.load_state_dict(lin1.state_dict())
 
     if dim1 != -1:
         assert isinstance(lin1.weight, torch.nn.Parameter)
         assert isinstance(lin2.weight, torch.nn.Parameter)
         assert torch.allclose(lin1.weight, lin2.weight)
-        assert torch.allclose(lin1.bias, lin2.bias)
         assert not hasattr(lin1, '_hook')
         assert not hasattr(lin2, '_hook')
     else:
@@ -61,6 +63,14 @@ def test_load_lazy_linear(dim1, dim2, device):
         assert hasattr(lin1, '_hook')
         assert hasattr(lin2, '_hook')
 
+    if bias:
+        assert isinstance(lin1.bias, torch.nn.Parameter)
+        assert isinstance(lin2.bias, torch.nn.Parameter)
+        assert torch.allclose(lin1.bias, lin2.bias)
+    else:
+        assert lin1.bias is None
+        assert lin2.bias is None
+
     with pytest.raises(RuntimeError, match="in state_dict"):
         lin1.load_state_dict({}, strict=True)
     lin1.load_state_dict({}, strict=False)
```
torch_geometric/nn/dense/linear.py (19 changes: 19 additions & 0 deletions)
```diff
@@ -150,6 +150,25 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         else:
             destination[prefix + 'bias'] = self.bias.detach()
 
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        weight = state_dict.get(prefix + 'weight', None)
+
+        if weight is not None and is_uninitialized_parameter(weight):
+            self.in_channels = -1
+            self.weight = torch.nn.parameter.UninitializedParameter()
+            if not hasattr(self, '_hook'):
+                self._hook = self.register_forward_pre_hook(
+                    self.initialize_parameters)
+
+        elif weight is not None and is_uninitialized_parameter(self.weight):
+            self.in_channels = weight.size(-1)
+            self.weight.materialize((self.out_channels, self.in_channels))
+            if hasattr(self, '_hook'):
+                self._hook.remove()
+                delattr(self, '_hook')
+
+        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
     def __repr__(self) -> str:
         return (f'{self.__class__.__name__}({self.in_channels}, '
                 f'{self.out_channels}, bias={self.bias is not None})')
```
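
To illustrate, a hedged usage sketch of the two branches the new `_load_from_state_dict` handles; the assertions mirror the updated test above rather than documenting anything beyond it:

```python
import torch
from torch_geometric.nn import Linear

# Materialized source -> lazy receiver: the receiver infers
# `in_channels` from the incoming weight and materializes it.
src, dst = Linear(16, 32), Linear(-1, 32)
dst.load_state_dict(src.state_dict())
assert dst.in_channels == 16
assert torch.allclose(dst.weight, src.weight)
assert not hasattr(dst, '_hook')  # initialization hook was removed

# Lazy source -> materialized receiver: the receiver reverts to the
# uninitialized state and re-registers its forward pre-hook.
src, dst = Linear(-1, 32), Linear(16, 32)
dst.load_state_dict(src.state_dict())
assert dst.in_channels == -1
assert hasattr(dst, '_hook')
```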
