Skip to content

Commit

Permalink
debug
Browse files Browse the repository at this point in the history
  • Loading branch information
riship committed Jan 11, 2025
1 parent 976d175 commit b6d80b0
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torch_geometric/utils/rag/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,13 @@ class RemoteGraphBackendLoader:
graph_store_type: Type[ConvertableGraphStore]
feature_store_type: Type[ConvertableFeatureStore]

def load(self, pid: Optional[int] = None,
is_sorted=False) -> RemoteGraphBackend:
def load(self, pid: Optional[int] = None) -> RemoteGraphBackend:
if self.datatype == RemoteDataType.DATA:
data_obj = torch.load(self.path, weights_only=False)
# is_sorted=true since assume nodes come sorted from indexer
graph_store = self.graph_store_type.from_data(
edge_id=data_obj['edge_id'], edge_index=data_obj.edge_index,
num_nodes=data_obj.num_nodes, is_sorted=is_sorted)
num_nodes=data_obj.num_nodes, is_sorted=True)
graph_store.edge_index = graph_store.edge_index.contiguous()
feature_store = self.feature_store_type.from_data(
node_id=data_obj['node_id'], x=data_obj.x,
Expand Down Expand Up @@ -237,7 +237,7 @@ class to use. Defaults to LocalFeatureStore.
if n_parts == 1:
torch.save(data, path)
return RemoteGraphBackendLoader(path, RemoteDataType.DATA, graph_db,
feature_db, is_sorted=True)
feature_db)
else:
partitioner = Partitioner(data=data, num_parts=n_parts, root=path)
partitioner.generate_partition()
Expand Down

0 comments on commit b6d80b0

Please sign in to comment.