diff --git a/grand/backends/_sqlbackend.py b/grand/backends/_sqlbackend.py index fea4ef1..6e68e7a 100644 --- a/grand/backends/_sqlbackend.py +++ b/grand/backends/_sqlbackend.py @@ -350,16 +350,16 @@ def get_node_by_id(self, node_name: Hashable): """ - res = ( - self._connection.execute( - self._node_table.select().where( - self._node_table.c[self._primary_key] == str(node_name) - ) + res = self._connection.execute( + self._node_table.select().where( + self._node_table.c[self._primary_key] == str(node_name) ) - .fetchone() - ._metadata - ) - return res + ).fetchone() + + if res: + return res._metadata + + raise KeyError(f"Node {node_name} not found") def get_edge_by_id(self, u: Hashable, v: Hashable): """ diff --git a/grand/backends/backend.py b/grand/backends/backend.py index 53aaa0f..b343437 100644 --- a/grand/backends/backend.py +++ b/grand/backends/backend.py @@ -339,6 +339,7 @@ class InMemoryCachedBackend(CachedBackend): "add_edge", "add_edges_from", "ingest_from_edgelist_dataframe", + "remove_node" ] _default_write_methods = [ @@ -347,6 +348,7 @@ class InMemoryCachedBackend(CachedBackend): "add_edge", "add_edges_from", "ingest_from_edgelist_dataframe", + "remove_node" ] def __init__( diff --git a/grand/backends/test_backends.py b/grand/backends/test_backends.py index 786e263..4e396e0 100644 --- a/grand/backends/test_backends.py +++ b/grand/backends/test_backends.py @@ -161,6 +161,8 @@ def test_sqlite_persistence(self): backend.remove_node(node1) assert not backend.has_node(node1) assert not backend.has_edge(node1, node2) + with pytest.raises(KeyError): + assert not backend.get_node_by_id(node1) # cleanup os.remove(dbpath)