From c75c7192f8048f90c8150d071a86a051d5cbb4f5 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Tue, 26 Mar 2024 10:53:12 +0100 Subject: [PATCH] Allow `None` outputs in `FeatureStore` (#9102) Moves any `KeyError` logic to `_get_tensor()` --- CHANGELOG.md | 1 + torch_geometric/data/feature_store.py | 32 +++++++++--------------- torch_geometric/testing/feature_store.py | 2 +- 3 files changed, 14 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 36e106a7c77d..8b2c62218371 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Allow `None` outputs in `FeatureStore.get_tensor()` - `KeyError` should now be raised based on the implementation in `FeatureStore._get_tensor()` ([#9102](https://github.com/pyg-team/pytorch_geometric/pull/9102)) - Allow mini-batching of uncoalesced sparse matrices ([#9099](https://github.com/pyg-team/pytorch_geometric/pull/9099)) - Improvements to multi-node `ogbn-papers100m` default hyperparameters and adding evaluation on all ranks ([#8823](https://github.com/pyg-team/pytorch_geometric/pull/8823)) - Changed distributed sampler and loader tests to correctly report failures in subprocesses to `pytest` ([#8978](https://github.com/pyg-team/pytorch_geometric/pull/8978)) diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py index e5a4a7973421..f1b5a2050eeb 100644 --- a/torch_geometric/data/feature_store.py +++ b/torch_geometric/data/feature_store.py @@ -28,6 +28,7 @@ import numpy as np import torch +from torch import Tensor from torch_geometric.typing import FeatureTensorType, NodeType from torch_geometric.utils.mixin import CastMixin @@ -329,8 +330,6 @@ def get_tensor( Raises: ValueError: If the input :class:`TensorAttr` is not fully specified. - KeyError: If the tensor corresponding to the input - :class:`TensorAttr` was not found. """ attr = self._tensor_attr_cls.cast(*args, **kwargs) if not attr.is_fully_specified(): @@ -339,9 +338,9 @@ def get_tensor( f"specifying all 'UNSET' fields.") tensor = self._get_tensor(attr) - if tensor is None: - raise KeyError(f"A tensor corresponding to '{attr}' was not found") - return self._to_type(attr, tensor) if convert_type else tensor + if convert_type: + tensor = self._to_type(attr, tensor) + return tensor def _multi_get_tensor( self, @@ -375,8 +374,6 @@ def multi_get_tensor( Raises: ValueError: If any input :class:`TensorAttr` is not fully specified. - KeyError: If any of the tensors corresponding to the input - :class:`TensorAttr` was not found. """ attrs = [self._tensor_attr_cls.cast(attr) for attr in attrs] bad_attrs = [attr for attr in attrs if not attr.is_fully_specified()] @@ -387,15 +384,12 @@ def multi_get_tensor( f"'UNSET' fields") tensors = self._multi_get_tensor(attrs) - if any(v is None for v in tensors): - bad_attrs = [attrs[i] for i, v in enumerate(tensors) if v is None] - raise KeyError(f"Tensors corresponding to attributes " - f"'{bad_attrs}' were not found") - - return [ - self._to_type(attr, tensor) if convert_type else tensor - for attr, tensor in zip(attrs, tensors) - ] + if convert_type: + tensors = [ + self._to_type(attr, tensor) + for attr, tensor in zip(attrs, tensors) + ] + return tensors @abstractmethod def _remove_tensor(self, attr: TensorAttr) -> bool: @@ -476,11 +470,9 @@ def _to_type( attr: TensorAttr, tensor: FeatureTensorType, ) -> FeatureTensorType: - if (isinstance(attr.index, torch.Tensor) - and isinstance(tensor, np.ndarray)): + if isinstance(attr.index, Tensor) and isinstance(tensor, np.ndarray): return torch.from_numpy(tensor) - if (isinstance(attr.index, np.ndarray) - and isinstance(tensor, torch.Tensor)): + if isinstance(attr.index, np.ndarray) and isinstance(tensor, Tensor): return tensor.detach().cpu().numpy() return tensor diff --git a/torch_geometric/testing/feature_store.py b/torch_geometric/testing/feature_store.py index 03b61f5eef48..c317a8afeb69 100644 --- a/torch_geometric/testing/feature_store.py +++ b/torch_geometric/testing/feature_store.py @@ -36,7 +36,7 @@ def _get_tensor(self, attr: TensorAttr) -> Optional[Tensor]: index, tensor = self.store.get(self.key(attr), (None, None)) if tensor is None: - return None + raise KeyError(f"Could not find tensor for '{attr}'") assert isinstance(tensor, Tensor)