From 2fcd29db64ea1608cee47b710d90aed7330ac553 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Tue, 26 Mar 2024 17:35:24 +0100 Subject: [PATCH] Ensure backward compatibility in `MessagePassing` via `torch.load` (#9105) --- CHANGELOG.md | 1 + test/nn/conv/test_message_passing.py | 16 +++++++++++++++ torch_geometric/nn/conv/message_passing.py | 23 ++++++++++++---------- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b2c62218371..523118102a7b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Ensure backward compatibility in `MessagePassing` via `torch.load` ([#9105](https://github.com/pyg-team/pytorch_geometric/pull/9105)) - Prevent model compilation on custom `propagate` functions ([#9079](https://github.com/pyg-team/pytorch_geometric/pull/9079)) - Ignore `self.propagate` appearances in comments when parsing `MessagePassing` implementation ([#9044](https://github.com/pyg-team/pytorch_geometric/pull/9044)) - Fixed `OSError` on read-only file systems within `MessagePassing` ([#9032](https://github.com/pyg-team/pytorch_geometric/pull/9032)) diff --git a/test/nn/conv/test_message_passing.py b/test/nn/conv/test_message_passing.py index 901c7ac87ddd..2f2544e765fa 100644 --- a/test/nn/conv/test_message_passing.py +++ b/test/nn/conv/test_message_passing.py @@ -130,6 +130,22 @@ def test_my_conv_basic(): assert torch_adj_t.grad is not None +def test_my_conv_save(tmp_path): + conv = MyConv(8, 32) + assert conv._jinja_propagate is not None + assert conv.__class__._jinja_propagate is not None + assert conv._orig_propagate is not None + assert conv.__class__._orig_propagate is not None + + path = osp.join(tmp_path, 'model.pt') + torch.save(conv, path) + conv = torch.load(path) + assert conv._jinja_propagate is not None + assert conv.__class__._jinja_propagate is not None + assert conv._orig_propagate is not None + assert conv.__class__._orig_propagate is not None + + def test_my_conv_edge_index(): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) diff --git a/torch_geometric/nn/conv/message_passing.py b/torch_geometric/nn/conv/message_passing.py index 0d9c1d2a291d..2d5b08c6875c 100644 --- a/torch_geometric/nn/conv/message_passing.py +++ b/torch_geometric/nn/conv/message_passing.py @@ -731,13 +731,14 @@ def decomposed_layers(self, decomposed_layers: int) -> None: self._decomposed_layers = decomposed_layers if decomposed_layers != 1: - self.propagate = self.__class__._orig_propagate.__get__( - self, MessagePassing) + if hasattr(self.__class__, '_orig_propagate'): + self.propagate = self.__class__._orig_propagate.__get__( + self, MessagePassing) - elif ((self.explain is None or self.explain is False) - and not self.propagate.__module__.endswith('_propagate')): - self.propagate = self.__class__._jinja_propagate.__get__( - self, MessagePassing) + elif self.explain is None or self.explain is False: + if hasattr(self.__class__, '_jinja_propagate'): + self.propagate = self.__class__._jinja_propagate.__get__( + self, MessagePassing) # Explainability ########################################################## @@ -761,16 +762,18 @@ def explain(self, explain: Optional[bool]) -> None: funcs=['message', 'explain_message', 'aggregate', 'update'], exclude=self.special_args, ) - self.propagate = self.__class__._orig_propagate.__get__( - self, MessagePassing) + if hasattr(self.__class__, '_orig_propagate'): + self.propagate = self.__class__._orig_propagate.__get__( + self, MessagePassing) else: self._user_args = self.inspector.get_flat_param_names( funcs=['message', 'aggregate', 'update'], exclude=self.special_args, ) if self.decomposed_layers == 1: - self.propagate = self.__class__._jinja_propagate.__get__( - self, MessagePassing) + if hasattr(self.__class__, '_jinja_propagate'): + self.propagate = self.__class__._jinja_propagate.__get__( + self, MessagePassing) def explain_message( self,