Skip to content

Commit

Permalink
Fix for new pyg_lib.neighbor_sample argument order (#8381)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Nov 15, 2023
1 parent 31b386b commit 987d767
Show file tree
Hide file tree
Showing 11 changed files with 18 additions and 6 deletions.
1 change: 1 addition & 0 deletions .github/workflows/full_gpu_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:
pip install -e .[full,test]
- name: Run tests
timeout-minutes: 20
run: |
FULL_TEST=1 pytest
shell: bash
1 change: 1 addition & 0 deletions .github/workflows/full_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ jobs:
pip install -e .[full,test]
- name: Run tests
timeout-minutes: 20
run: |
FULL_TEST=1 pytest --cov --cov-report=xml
shell: bash
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/latest_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,6 @@ jobs:
- name: Run tests
if: steps.changed-files-specific.outputs.only_changed != 'true'
timeout-minutes: 10
run: |
pytest
1 change: 1 addition & 0 deletions .github/workflows/minimal_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,6 @@ jobs:
- name: Run tests
if: steps.changed-files-specific.outputs.only_changed != 'true'
timeout-minutes: 10
run: |
pytest
1 change: 1 addition & 0 deletions .github/workflows/prev_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,6 @@ jobs:
- name: Run tests
if: steps.changed-files-specific.outputs.only_changed != 'true'
timeout-minutes: 10
run: |
pytest
1 change: 1 addition & 0 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ jobs:
- name: Run tests
if: steps.changed-files-specific.outputs.only_changed != 'true'
timeout-minutes: 10
run: |
pytest --cov --cov-report=xml --durations 10
Expand Down
6 changes: 3 additions & 3 deletions test/loader/test_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def test_pyg_lib_and_torch_sparse_homo_equality():
seed = torch.arange(10)

sample = torch.ops.pyg.neighbor_sample
out1 = sample(colptr, row, seed, [-1, -1], None, None, None, True)
out1 = sample(colptr, row, seed, [-1, -1], None, None, None, None, True)
sample = torch.ops.torch_sparse.neighbor_sample
out2 = sample(colptr, row, seed, [-1, -1], False, True)

Expand Down Expand Up @@ -570,8 +570,8 @@ def test_pyg_lib_and_torch_sparse_hetero_equality():

sample = torch.ops.pyg.hetero_neighbor_sample
out1 = sample(node_types, edge_types, colptr_dict, row_dict, seed_dict,
num_neighbors_dict, None, None, None, True, False, True,
False, "uniform", True)
num_neighbors_dict, None, None, None, None, True, False,
True, False, "uniform", True)
sample = torch.ops.torch_sparse.hetero_neighbor_sample
out2 = sample(node_types, edge_types, colptr_dict, row_dict, seed_dict,
num_neighbors_dict, 2, False, True)
Expand Down
1 change: 0 additions & 1 deletion test/test_config_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,3 @@ def test_hydra_config_store():
assert cfg.lr_scheduler.cooldown == 0
assert cfg.lr_scheduler.min_lr == 0
assert cfg.lr_scheduler.eps == 1e-08
assert not cfg.lr_scheduler.verbose
1 change: 1 addition & 0 deletions torch_geometric/distributed/dist_neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,7 @@ def _sample_one_hop(
input_nodes.to(colptr.dtype),
num_neighbors,
node_time,
None, # edge_time
seed_time,
None, # TODO: edge_weight
True, # csc
Expand Down
8 changes: 6 additions & 2 deletions torch_geometric/sampler/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,10 @@ def _sample(
seed,
self.num_neighbors.get_mapped_values(self.edge_types),
self.node_time,
seed_time,
)
if torch_geometric.typing.WITH_EDGE_TIME_NEIGHBOR_SAMPLE:
args += (None, )
args += (seed_time, )
if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE:
args += (self.edge_weight, )
args += (
Expand Down Expand Up @@ -380,8 +382,10 @@ def _sample(
seed.to(self.colptr.dtype),
self.num_neighbors.get_mapped_values(),
self.node_time,
seed_time,
)
if torch_geometric.typing.WITH_EDGE_TIME_NEIGHBOR_SAMPLE:
args += (None, )
args += (seed_time, )
if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE:
args += (self.edge_weight, )
args += (
Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
WITH_SAMPLED_OP = hasattr(pyg_lib.ops, 'sampled_add')
WITH_INDEX_SORT = hasattr(pyg_lib.ops, 'index_sort')
WITH_METIS = hasattr(pyg_lib, 'partition')
WITH_EDGE_TIME_NEIGHBOR_SAMPLE = ('edge_time' in inspect.signature(
pyg_lib.sampler.neighbor_sample).parameters)
WITH_WEIGHTED_NEIGHBOR_SAMPLE = ('edge_weight' in inspect.signature(
pyg_lib.sampler.neighbor_sample).parameters)
except Exception as e:
Expand Down

0 comments on commit 987d767

Please sign in to comment.