From 9cfbc5f66e011889fe8a30fd0f404ff685216cb0 Mon Sep 17 00:00:00 2001 From: davidmezzetti <561939+davidmezzetti@users.noreply.github.com> Date: Sat, 7 Dec 2024 07:08:59 -0500 Subject: [PATCH] Add remove_node method to SQLBackend --- grand/backends/_sqlbackend.py | 30 ++++++++++++++++++++++++++---- grand/backends/test_backends.py | 11 +++++++++++ 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/grand/backends/_sqlbackend.py b/grand/backends/_sqlbackend.py index 2f71493..6410de8 100644 --- a/grand/backends/_sqlbackend.py +++ b/grand/backends/_sqlbackend.py @@ -1,11 +1,10 @@ -from typing import Hashable, Generator, Optional, Iterable +from typing import Hashable, Generator import time import pandas as pd import sqlalchemy -from sqlalchemy.pool import NullPool -from sqlalchemy.sql import select -from sqlalchemy import and_, or_, func, Index +from sqlalchemy.sql import delete, select +from sqlalchemy import or_, func, Index from .backend import Backend @@ -192,6 +191,29 @@ def _upsert_node(self, node_name: Hashable, metadata: dict) -> Hashable: parameters={self._primary_key: node_name, "_metadata": metadata}, ) + def remove_node(self, name: Hashable) -> None: + """ + Removes nodes and related edges for name. + + Args: + node_name (Hashable): id of the node + """ + + # Remove nodes + statement = delete(self._node_table).where( + self._node_table.c[self._primary_key] == str(name) + ) + self._connection.execute(statement) + + # Remove edges for node + statement = delete(self._edge_table).where( + or_( + self._edge_table.c[self._edge_source_key] == str(name), + self._edge_table.c[self._edge_target_key] == str(name) + ) + ) + self._connection.execute(statement) + def all_nodes_as_iterable(self, include_metadata: bool = False) -> Generator: """ Get a generator of all of the nodes in this graph. diff --git a/grand/backends/test_backends.py b/grand/backends/test_backends.py index cb0d098..786e263 100644 --- a/grand/backends/test_backends.py +++ b/grand/backends/test_backends.py @@ -151,6 +151,17 @@ def test_sqlite_persistence(self): nodes = list(backend.all_nodes_as_iterable()) # assert assert node0 in nodes + + # test remove_node + backend = SQLBackend(db_url=url, directed=True) + node1, node2 = backend.add_node("A", {}), backend.add_node("B", {}) + backend.add_edge(node1, node2, {}) + assert backend.has_node(node1) + assert backend.has_edge(node1, node2) + backend.remove_node(node1) + assert not backend.has_node(node1) + assert not backend.has_edge(node1, node2) + # cleanup os.remove(dbpath)