Skip to content

Commit

Permalink
Allow None outputs in FeatureStore (#9102)
Browse files Browse the repository at this point in the history
Moves any `KeyError` logic to `_get_tensor()`
  • Loading branch information
rusty1s authored Mar 26, 2024
1 parent ceeea03 commit c75c719
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 21 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Allow `None` outputs in `FeatureStore.get_tensor()` - `KeyError` should now be raised based on the implementation in `FeatureStore._get_tensor()` ([#9102](https://github.com/pyg-team/pytorch_geometric/pull/9102))
- Allow mini-batching of uncoalesced sparse matrices ([#9099](https://github.com/pyg-team/pytorch_geometric/pull/9099))
- Improvements to multi-node `ogbn-papers100m` default hyperparameters and adding evaluation on all ranks ([#8823](https://github.com/pyg-team/pytorch_geometric/pull/8823))
- Changed distributed sampler and loader tests to correctly report failures in subprocesses to `pytest` ([#8978](https://github.com/pyg-team/pytorch_geometric/pull/8978))
Expand Down
32 changes: 12 additions & 20 deletions torch_geometric/data/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import numpy as np
import torch
from torch import Tensor

from torch_geometric.typing import FeatureTensorType, NodeType
from torch_geometric.utils.mixin import CastMixin
Expand Down Expand Up @@ -329,8 +330,6 @@ def get_tensor(
Raises:
ValueError: If the input :class:`TensorAttr` is not fully
specified.
KeyError: If the tensor corresponding to the input
:class:`TensorAttr` was not found.
"""
attr = self._tensor_attr_cls.cast(*args, **kwargs)
if not attr.is_fully_specified():
Expand All @@ -339,9 +338,9 @@ def get_tensor(
f"specifying all 'UNSET' fields.")

tensor = self._get_tensor(attr)
if tensor is None:
raise KeyError(f"A tensor corresponding to '{attr}' was not found")
return self._to_type(attr, tensor) if convert_type else tensor
if convert_type:
tensor = self._to_type(attr, tensor)
return tensor

def _multi_get_tensor(
self,
Expand Down Expand Up @@ -375,8 +374,6 @@ def multi_get_tensor(
Raises:
ValueError: If any input :class:`TensorAttr` is not fully
specified.
KeyError: If any of the tensors corresponding to the input
:class:`TensorAttr` was not found.
"""
attrs = [self._tensor_attr_cls.cast(attr) for attr in attrs]
bad_attrs = [attr for attr in attrs if not attr.is_fully_specified()]
Expand All @@ -387,15 +384,12 @@ def multi_get_tensor(
f"'UNSET' fields")

tensors = self._multi_get_tensor(attrs)
if any(v is None for v in tensors):
bad_attrs = [attrs[i] for i, v in enumerate(tensors) if v is None]
raise KeyError(f"Tensors corresponding to attributes "
f"'{bad_attrs}' were not found")

return [
self._to_type(attr, tensor) if convert_type else tensor
for attr, tensor in zip(attrs, tensors)
]
if convert_type:
tensors = [
self._to_type(attr, tensor)
for attr, tensor in zip(attrs, tensors)
]
return tensors

@abstractmethod
def _remove_tensor(self, attr: TensorAttr) -> bool:
Expand Down Expand Up @@ -476,11 +470,9 @@ def _to_type(
attr: TensorAttr,
tensor: FeatureTensorType,
) -> FeatureTensorType:
if (isinstance(attr.index, torch.Tensor)
and isinstance(tensor, np.ndarray)):
if isinstance(attr.index, Tensor) and isinstance(tensor, np.ndarray):
return torch.from_numpy(tensor)
if (isinstance(attr.index, np.ndarray)
and isinstance(tensor, torch.Tensor)):
if isinstance(attr.index, np.ndarray) and isinstance(tensor, Tensor):
return tensor.detach().cpu().numpy()
return tensor

Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/testing/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _get_tensor(self, attr: TensorAttr) -> Optional[Tensor]:
index, tensor = self.store.get(self.key(attr), (None, None))

if tensor is None:
return None
raise KeyError(f"Could not find tensor for '{attr}'")

assert isinstance(tensor, Tensor)

Expand Down

0 comments on commit c75c719

Please sign in to comment.