Skip to content

Commit

Permalink
removing seed nodes from the ragquerylaoder and feature store since t…
Browse files Browse the repository at this point in the history
…heir unused and cause wierdness
  • Loading branch information
riship committed Jan 9, 2025
1 parent ec71020 commit 53572aa
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
7 changes: 4 additions & 3 deletions torch_geometric/loader/rag_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,12 @@ def query(self, query: Any) -> Data:
"""
seed_nodes = self.feature_store.retrieve_seed_nodes(
query, **self.seed_nodes_kwargs)
seed_edges = self.feature_store.retrieve_seed_edges(
query, **self.seed_edges_kwargs)
# Graph Store does not Use These, save computation
# seed_edges = self.feature_store.retrieve_seed_edges(
# query, **self.seed_edges_kwargs)

subgraph_sample = self.graph_store.sample_subgraph(
seed_nodes, seed_edges, **self.sampler_kwargs)
seed_nodes, **self.sampler_kwargs)

data = self.feature_store.load_subgraph(sample=subgraph_sample,
**self.loader_kwargs)
Expand Down
10 changes: 6 additions & 4 deletions torch_geometric/utils/rag/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,16 @@ def num_neighbors(self, num_neighbors: NumNeighborsType):
self.sampler.num_neighbors = num_neighbors

def sample_subgraph(
self, seed_nodes: InputNodes, seed_edges: InputEdges,
self, seed_nodes: InputNodes, seed_edges: Optional[InputEdges] = None,
num_neighbors: Optional[NumNeighborsType] = None
) -> Union[SamplerOutput, HeteroSamplerOutput]:
"""Sample the graph starting from the given nodes and edges using the
in-built NeighborSampler.
Args:
seed_nodes (InputNodes): Seed nodes to start sampling from.
seed_edges (InputEdges): Seed edges to start sampling from.
seed_edges (Optional[InputEdges], optional): Seed edges to start sampling from.
Defaults to None.
num_neighbors (Optional[NumNeighborsType], optional): Parameters
to determine how many hops and number of neighbors per hop.
Defaults to None.
Expand All @@ -89,8 +90,9 @@ def sample_subgraph(
# FIXME: Right now, only input nodes/edges as tensors are be supported
if not isinstance(seed_nodes, Tensor):
raise NotImplementedError
if not isinstance(seed_edges, Tensor):
raise NotImplementedError
if seed_edges:
if not isinstance(seed_edges, Tensor):
raise NotImplementedError
seed_nodes.device

# TODO: Call sample_from_edges for seed_edges
Expand Down

0 comments on commit 53572aa

Please sign in to comment.