Skip to content

Commit

Permalink
Add remove_node method to SQLBackend
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Dec 7, 2024
1 parent 71286f3 commit 9cfbc5f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
30 changes: 26 additions & 4 deletions grand/backends/_sqlbackend.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Hashable, Generator, Optional, Iterable
from typing import Hashable, Generator
import time

import pandas as pd
import sqlalchemy
from sqlalchemy.pool import NullPool
from sqlalchemy.sql import select
from sqlalchemy import and_, or_, func, Index
from sqlalchemy.sql import delete, select
from sqlalchemy import or_, func, Index

from .backend import Backend

Expand Down Expand Up @@ -192,6 +191,29 @@ def _upsert_node(self, node_name: Hashable, metadata: dict) -> Hashable:
parameters={self._primary_key: node_name, "_metadata": metadata},
)

def remove_node(self, name: Hashable) -> None:
"""
Removes nodes and related edges for name.
Args:
node_name (Hashable): id of the node
"""

# Remove nodes
statement = delete(self._node_table).where(
self._node_table.c[self._primary_key] == str(name)
)
self._connection.execute(statement)

# Remove edges for node
statement = delete(self._edge_table).where(
or_(
self._edge_table.c[self._edge_source_key] == str(name),
self._edge_table.c[self._edge_target_key] == str(name)
)
)
self._connection.execute(statement)

def all_nodes_as_iterable(self, include_metadata: bool = False) -> Generator:
"""
Get a generator of all of the nodes in this graph.
Expand Down
11 changes: 11 additions & 0 deletions grand/backends/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,17 @@ def test_sqlite_persistence(self):
nodes = list(backend.all_nodes_as_iterable())
# assert
assert node0 in nodes

# test remove_node
backend = SQLBackend(db_url=url, directed=True)
node1, node2 = backend.add_node("A", {}), backend.add_node("B", {})
backend.add_edge(node1, node2, {})
assert backend.has_node(node1)
assert backend.has_edge(node1, node2)
backend.remove_node(node1)
assert not backend.has_node(node1)
assert not backend.has_edge(node1, node2)

# cleanup
os.remove(dbpath)

Expand Down

0 comments on commit 9cfbc5f

Please sign in to comment.