Skip to content

Commit

Permalink
Ensure backward compatibility in MessagePassing via torch.load (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Mar 26, 2024
1 parent c75c719 commit 2fcd29d
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Ensure backward compatibility in `MessagePassing` via `torch.load` ([#9105](https://github.com/pyg-team/pytorch_geometric/pull/9105))
- Prevent model compilation on custom `propagate` functions ([#9079](https://github.com/pyg-team/pytorch_geometric/pull/9079))
- Ignore `self.propagate` appearances in comments when parsing `MessagePassing` implementation ([#9044](https://github.com/pyg-team/pytorch_geometric/pull/9044))
- Fixed `OSError` on read-only file systems within `MessagePassing` ([#9032](https://github.com/pyg-team/pytorch_geometric/pull/9032))
Expand Down
16 changes: 16 additions & 0 deletions test/nn/conv/test_message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,22 @@ def test_my_conv_basic():
assert torch_adj_t.grad is not None


def test_my_conv_save(tmp_path):
conv = MyConv(8, 32)
assert conv._jinja_propagate is not None
assert conv.__class__._jinja_propagate is not None
assert conv._orig_propagate is not None
assert conv.__class__._orig_propagate is not None

path = osp.join(tmp_path, 'model.pt')
torch.save(conv, path)
conv = torch.load(path)
assert conv._jinja_propagate is not None
assert conv.__class__._jinja_propagate is not None
assert conv._orig_propagate is not None
assert conv.__class__._orig_propagate is not None


def test_my_conv_edge_index():
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
Expand Down
23 changes: 13 additions & 10 deletions torch_geometric/nn/conv/message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,13 +731,14 @@ def decomposed_layers(self, decomposed_layers: int) -> None:
self._decomposed_layers = decomposed_layers

if decomposed_layers != 1:
self.propagate = self.__class__._orig_propagate.__get__(
self, MessagePassing)
if hasattr(self.__class__, '_orig_propagate'):
self.propagate = self.__class__._orig_propagate.__get__(
self, MessagePassing)

elif ((self.explain is None or self.explain is False)
and not self.propagate.__module__.endswith('_propagate')):
self.propagate = self.__class__._jinja_propagate.__get__(
self, MessagePassing)
elif self.explain is None or self.explain is False:
if hasattr(self.__class__, '_jinja_propagate'):
self.propagate = self.__class__._jinja_propagate.__get__(
self, MessagePassing)

# Explainability ##########################################################

Expand All @@ -761,16 +762,18 @@ def explain(self, explain: Optional[bool]) -> None:
funcs=['message', 'explain_message', 'aggregate', 'update'],
exclude=self.special_args,
)
self.propagate = self.__class__._orig_propagate.__get__(
self, MessagePassing)
if hasattr(self.__class__, '_orig_propagate'):
self.propagate = self.__class__._orig_propagate.__get__(
self, MessagePassing)
else:
self._user_args = self.inspector.get_flat_param_names(
funcs=['message', 'aggregate', 'update'],
exclude=self.special_args,
)
if self.decomposed_layers == 1:
self.propagate = self.__class__._jinja_propagate.__get__(
self, MessagePassing)
if hasattr(self.__class__, '_jinja_propagate'):
self.propagate = self.__class__._jinja_propagate.__get__(
self, MessagePassing)

def explain_message(
self,
Expand Down

0 comments on commit 2fcd29d

Please sign in to comment.