From 046b95ab1e93252a99cfd9b9956afacf40edda02 Mon Sep 17 00:00:00 2001
From: David Mezzetti <561939+davidmezzetti@users.noreply.github.com>
Date: Thu, 11 Apr 2024 16:43:30 -0400
Subject: [PATCH] Update SQL backend to support SQLAlchemy 2.x (#45)

* Update SQL backend to support SQLAlchemy 2.x, add batch inserts for nodes/edges, add edge index

* Fix unit test errors

* Update Python versions for build script
---
 .github/workflows/python-package.yml |   2 +-
 grand/backends/_sqlbackend.py        | 217 +++++++++++++++------------
 grand/backends/backend.py            |  26 ++++
 grand/dialects/__init__.py           |   6 +
 4 files changed, 156 insertions(+), 95 deletions(-)

diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 9927ae8..0736981 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -15,7 +15,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [3.7, 3.8, 3.9, '3.10']
+        python-version: [3.8, 3.9, '3.10', '3.11']
 
     steps:
     - uses: actions/checkout@v2
diff --git a/grand/backends/_sqlbackend.py b/grand/backends/_sqlbackend.py
index 0630948..99996d4 100644
--- a/grand/backends/_sqlbackend.py
+++ b/grand/backends/_sqlbackend.py
@@ -5,7 +5,7 @@
 import sqlalchemy
 from sqlalchemy.pool import NullPool
 from sqlalchemy.sql import select
-from sqlalchemy import and_, or_, func
+from sqlalchemy import and_, or_, func, Index
 
 from .backend import Backend
 
@@ -60,51 +60,48 @@ def __init__(
         self._connection = self._engine.connect()
         self._metadata = sqlalchemy.MetaData()
 
-        if not self._engine.dialect.has_table(self._connection, self._node_table_name):
-            self._node_table = sqlalchemy.Table(
-                self._node_table_name,
-                self._metadata,
-                sqlalchemy.Column(
-                    self._primary_key,
-                    sqlalchemy.String(_DEFAULT_SQL_STR_LEN),
-                    primary_key=True,
-                ),
-                sqlalchemy.Column("_metadata", sqlalchemy.JSON),
-            )
-            self._node_table.create(self._engine)
-        else:
-            self._node_table = sqlalchemy.Table(
-                self._node_table_name,
-                self._metadata,
-                autoload=True,
-                autoload_with=self._engine,
-            )
+        # Create nodes table
+        self._node_table = sqlalchemy.Table(
+            self._node_table_name,
+            self._metadata,
+            sqlalchemy.Column(
+                self._primary_key,
+                sqlalchemy.String(_DEFAULT_SQL_STR_LEN),
+                primary_key=True,
+            ),
+            sqlalchemy.Column("_metadata", sqlalchemy.JSON),
+        )
+        self._node_table.create(self._engine, checkfirst=True)
 
-        if not self._engine.dialect.has_table(self._connection, self._edge_table_name):
-            self._edge_table = sqlalchemy.Table(
-                self._edge_table_name,
-                self._metadata,
-                sqlalchemy.Column(
-                    self._primary_key,
-                    sqlalchemy.String(_DEFAULT_SQL_STR_LEN),
-                    primary_key=True,
-                ),
-                sqlalchemy.Column("_metadata", sqlalchemy.JSON),
-                sqlalchemy.Column(
-                    self._edge_source_key, sqlalchemy.String(_DEFAULT_SQL_STR_LEN)
-                ),
-                sqlalchemy.Column(
-                    self._edge_target_key, sqlalchemy.String(_DEFAULT_SQL_STR_LEN)
-                ),
-            )
-            self._edge_table.create(self._engine)
-        else:
-            self._edge_table = sqlalchemy.Table(
-                self._edge_table_name,
-                self._metadata,
-                autoload=True,
-                autoload_with=self._engine,
-            )
+        source_column = sqlalchemy.Column(
+            self._edge_source_key, sqlalchemy.String(_DEFAULT_SQL_STR_LEN)
+        )
+
+        target_column = sqlalchemy.Column(
+            self._edge_target_key, sqlalchemy.String(_DEFAULT_SQL_STR_LEN)
+        )
+
+        # Create edges table
+        self._edge_table = sqlalchemy.Table(
+            self._edge_table_name,
+            self._metadata,
+            sqlalchemy.Column(
+                self._primary_key,
+                sqlalchemy.String(_DEFAULT_SQL_STR_LEN),
+                primary_key=True,
+            ),
+            sqlalchemy.Column("_metadata", sqlalchemy.JSON),
+            source_column,
+            target_column
+        )
+        self._edge_table.create(self._engine, checkfirst=True)
+
+        # Create source and target index
+        sindex = Index("edge_source", source_column)
+        sindex.create(self._engine, checkfirst=True)
+
+        tindex = Index("edge_target", target_column)
+        tindex.create(self._engine, checkfirst=True)
 
     def is_directed(self) -> bool:
         """
@@ -147,17 +144,25 @@ def add_node(self, node_name: Hashable, metadata: dict) -> Hashable:
             existing_metadata.update(metadata)
             self._connection.execute(
                 self._node_table.update().where(
-                    self._node_table.c[self._primary_key] == node_name
+                    self._node_table.c[self._primary_key] == str(node_name)
                 ),
-                **{"_metadata": existing_metadata},
+                parameters={"_metadata": existing_metadata},
             )
         else:
             self._connection.execute(
                 self._node_table.insert(),
-                **{self._primary_key: node_name, "_metadata": metadata},
+                parameters={self._primary_key: node_name, "_metadata": metadata},
             )
         return node_name
 
+    def add_nodes_from(self, nodes_for_adding, **attr):
+        nodes = [{
+            self._primary_key: node,
+            "_metadata": {**attr, **metadata},
+        } for node, metadata in nodes_for_adding]
+
+        self._connection.execute(self._node_table.insert(), nodes)
+
     def _upsert_node(self, node_name: Hashable, metadata: dict) -> Hashable:
         """
         Add a new node to the graph, or update an existing one.
@@ -174,14 +179,14 @@ def _upsert_node(self, node_name: Hashable, metadata: dict) -> Hashable:
         if node_exists:
             self._connection.execute(
                 self._node_table.update().where(
-                    self._node_table.c[self._primary_key] == node_name
+                    self._node_table.c[self._primary_key] == str(node_name)
                 ),
-                **{"_metadata": metadata},
+                parameters={"_metadata": metadata},
            )
         else:
             self._connection.execute(
                 self._node_table.insert(),
-                **{self._primary_key: node_name, "_metadata": metadata},
+                parameters={self._primary_key: node_name, "_metadata": metadata},
             )
 
     def all_nodes_as_iterable(self, include_metadata: bool = False) -> Generator:
@@ -196,10 +201,16 @@ def all_nodes_as_iterable(self, include_metadata: bool = False) -> Generator:
         Returns:
             Generator: A generator of all nodes (arbitrary sort)
 
         """
-        results = self._connection.execute(self._node_table.select()).fetchall()
         if include_metadata:
-            return [(row[self._primary_key], row["_metadata"]) for row in results]
-        return [row[self._primary_key] for row in results]
+            sql = self._node_table.select()
+        else:
+            sql = self._node_table.select().with_only_columns(self._node_table.c[self._primary_key])
+
+        results = []
+        for x in self._connection.execute(sql):
+            results.append(x if include_metadata else x[0])
+
+        return results
 
     def has_node(self, u: Hashable) -> bool:
@@ -214,7 +225,7 @@ def has_node(self, u: Hashable) -> bool:
         return len(
             self._connection.execute(
                 self._node_table.select().where(
-                    self._node_table.c[self._primary_key] == u
+                    self._node_table.c[self._primary_key] == str(u)
                 )
             ).fetchall()
         )
@@ -245,7 +256,7 @@ def add_edge(self, u: Hashable, v: Hashable, metadata: dict):
         try:
             self._connection.execute(
                 self._edge_table.insert(),
-                **{
+                parameters={
                     self._primary_key: pk,
                     self._edge_source_key: u,
                     self._edge_target_key: v,
@@ -260,11 +271,21 @@ def add_edge(self, u: Hashable, v: Hashable, metadata: dict):
                 self._edge_table.update().where(
                     self._edge_table.c[self._primary_key] == pk
                 ),
-                **{"_metadata": existing_metadata},
+                parameters={"_metadata": existing_metadata},
             )
 
         return pk
 
+    def add_edges_from(self, ebunch_to_add, **attr):
+        edges = [{
+            self._primary_key: f"__{u}__{v}",
+            self._edge_source_key: u,
+            self._edge_target_key: v,
+            "_metadata": {**attr, **metadata},
+        } for u, v, metadata in ebunch_to_add]
+
+        self._connection.execute(self._edge_table.insert(), edges)
+
     def all_edges_as_iterable(self, include_metadata: bool = False) -> Generator:
         """
         Get a list of all edges in this graph, arbitrary sort.
@@ -274,16 +295,18 @@ def all_edges_as_iterable(self, include_metadata: bool = False) -> Generator:
         Returns:
             Generator: A generator of all edges (arbitrary sort)
 
         """
-        return iter(
-            [
-                (e.Source, e.Target, e._metadata)
-                if include_metadata
-                else (e.Source, e.Target)
-                for e in self._connection.execute(self._edge_table.select()).fetchall()
-            ]
-        )
+        columns = [
+            self._node_table.c[self._edge_source_key],
+            self._node_table.c[self._edge_target_key]
+        ]
+
+        if include_metadata:
+            columns.append(self._node_table.c["_metadata"])
+
+        sql = self._node_table.select().with_only_columns(columns)
+        return self._connection.execute(sql).fetchall()
 
     def get_node_by_id(self, node_name: Hashable):
         """
         Return the data associated with a node.
@@ -296,10 +319,11 @@ def get_node_by_id(self, node_name: Hashable):
             dict: The metadata associated with this node
 
         """
+
         res = (
             self._connection.execute(
                 self._node_table.select().where(
-                    self._node_table.c[self._primary_key] == node_name
+                    self._node_table.c[self._primary_key] == str(node_name)
                 )
             )
             .fetchone()
@@ -357,27 +381,30 @@ def get_node_neighbors(
             Generator
 
         """
+
         if self._directed:
             res = self._connection.execute(
                 self._edge_table.select().where(
-                    self._edge_table.c[self._edge_source_key] == u
-                )
+                    self._edge_table.c[self._edge_source_key] == str(u)
+                ).order_by(self._edge_table.c[self._primary_key])
             ).fetchall()
         else:
             res = self._connection.execute(
                 self._edge_table.select().where(
                     or_(
-                        (self._edge_table.c[self._edge_source_key] == u),
-                        (self._edge_table.c[self._edge_target_key] == u),
+                        (self._edge_table.c[self._edge_source_key] == str(u)),
+                        (self._edge_table.c[self._edge_target_key] == str(u)),
                     )
-                )
+                ).order_by(self._edge_table.c[self._primary_key])
             ).fetchall()
 
+        res = [x._asdict() for x in res]
+
         if include_metadata:
             return {
                 (
                     r[self._edge_source_key]
-                    if r[self._edge_source_key] != u
+                    if r[self._edge_source_key] != str(u)
                     else r[self._edge_target_key]
                 ): r["_metadata"]
                 for r in res
@@ -387,7 +414,7 @@ def get_node_neighbors(
                 [
                     (
                         r[self._edge_source_key]
-                        if r[self._edge_source_key] != u
+                        if r[self._edge_source_key] != str(u)
                         else r[self._edge_target_key]
                     )
                     for r in res
@@ -410,24 +437,26 @@ def get_node_predecessors(
         if self._directed:
             res = self._connection.execute(
                 self._edge_table.select().where(
-                    self._edge_table.c[self._edge_target_key] == u
-                )
+                    self._edge_table.c[self._edge_target_key] == str(u)
+                ).order_by(self._edge_table.c[self._primary_key])
             ).fetchall()
         else:
             res = self._connection.execute(
                 self._edge_table.select().where(
                     or_(
-                        (self._edge_table.c[self._edge_target_key] == u),
-                        (self._edge_table.c[self._edge_source_key] == u),
+                        (self._edge_table.c[self._edge_target_key] == str(u)),
+                        (self._edge_table.c[self._edge_source_key] == str(u)),
                     )
-                )
+                ).order_by(self._edge_table.c[self._primary_key])
            ).fetchall()
 
+        res = [x._asdict() for x in res]
+
         if include_metadata:
             return {
                 (
                     r[self._edge_source_key]
-                    if r[self._edge_source_key] != u
+                    if r[self._edge_source_key] != str(u)
                     else r[self._edge_target_key]
                 ): r["_metadata"]
                 for r in res
@@ -437,7 +466,7 @@ def get_node_predecessors(
                 [
                     (
                         r[self._edge_source_key]
-                        if r[self._edge_source_key] != u
+                        if r[self._edge_source_key] != str(u)
                         else r[self._edge_target_key]
                     )
                     for r in res
@@ -456,7 +485,7 @@ def get_node_count(self) -> Iterable:
 
         """
         return self._connection.execute(
-            select([func.count()]).select_from(self._node_table)
+            select(func.count()).select_from(self._node_table)
         ).scalar()
 
     def out_degrees(self, nbunch=None):
@@ -474,20 +503,20 @@ def out_degrees(self, nbunch=None):
         if nbunch is None:
             where_clause = None
         elif isinstance(nbunch, (list, tuple)):
-            where_clause = self._edge_table.c[self._edge_source_key].in_(nbunch)
+            where_clause = self._edge_table.c[self._edge_source_key].in_([str(x) for x in nbunch])
         else:  # single node:
-            where_clause = self._edge_table.c[self._edge_source_key] == nbunch
+            where_clause = self._edge_table.c[self._edge_source_key] == str(nbunch)
 
         if self._directed:
             query = (
-                select([self._edge_table.c[self._edge_source_key], func.count()])
+                select(self._edge_table.c[self._edge_source_key], func.count())
                 .select_from(self._edge_table)
                 .group_by(self._edge_table.c[self._edge_source_key])
             )
         else:
             query = (
-                select([self._edge_table.c[self._edge_source_key], func.count()])
+                select(self._edge_table.c[self._edge_source_key], func.count())
                 .select_from(self._edge_table)
                 .group_by(self._edge_table.c[self._edge_source_key])
             )
 
@@ -496,8 +525,8 @@ def out_degrees(self, nbunch=None):
             query = query.where(where_clause)
 
         results = {
-            r[self._edge_source_key]: r[1]
-            for r in self._connection.execute(query).fetchall()
+            r[0]: r[1]
+            for r in self._connection.execute(query)
         }
 
         if nbunch and not isinstance(nbunch, (list, tuple)):
@@ -519,20 +548,20 @@ def in_degrees(self, nbunch=None):
         if nbunch is None:
             where_clause = None
         elif isinstance(nbunch, (list, tuple)):
-            where_clause = self._edge_table.c[self._edge_target_key].in_(nbunch)
+            where_clause = self._edge_table.c[self._edge_target_key].in_([str(x) for x in nbunch])
         else:  # single node:
-            where_clause = self._edge_table.c[self._edge_target_key] == nbunch
+            where_clause = self._edge_table.c[self._edge_target_key] == str(nbunch)
 
         if self._directed:
             query = (
-                select([self._edge_table.c[self._edge_target_key], func.count()])
+                select(self._edge_table.c[self._edge_target_key], func.count())
                 .select_from(self._edge_table)
                 .group_by(self._edge_table.c[self._edge_target_key])
             )
         else:
             query = (
-                select([self._edge_table.c[self._edge_target_key], func.count()])
+                select(self._edge_table.c[self._edge_target_key], func.count())
                 .select_from(self._edge_table)
                 .group_by(self._edge_table.c[self._edge_target_key])
             )
 
@@ -541,8 +570,8 @@ def in_degrees(self, nbunch=None):
             query = query.where(where_clause)
 
         results = {
-            r[self._edge_target_key]: r[1]
-            for r in self._connection.execute(query).fetchall()
+            r[0]: r[1]
+            for r in self._connection.execute(query)
         }
 
         if nbunch and not isinstance(nbunch, (list, tuple)):
diff --git a/grand/backends/backend.py b/grand/backends/backend.py
index b1c5f2b..83e0b24 100644
--- a/grand/backends/backend.py
+++ b/grand/backends/backend.py
@@ -65,6 +65,17 @@ def add_node(self, node_name: Hashable, metadata: dict):
         """
         ...
 
+    def add_nodes_from(self, nodes_for_adding, **attr):
+        """
+        Add nodes to the graph.
+
+        Arguments:
+            nodes_for_adding: nodes to add
+            attr: additional attributes
+        """
+        for node, metadata in nodes_for_adding:
+            self.add_node(node, {**attr, **metadata})
+
     def get_node_by_id(self, node_name: Hashable):
         """
         Return the data associated with a node.
@@ -125,6 +136,17 @@ def add_edge(self, u: Hashable, v: Hashable, metadata: dict):
         """
         ...
 
+    def add_edges_from(self, ebunch_to_add, **attr):
+        """
+        Add new edges to the graph.
+
+        Arguments:
+            ebunch_to_add: list of (source, target, metadata)
+            attr: additional common attributes
+        """
+        for u, v, metadata in ebunch_to_add:
+            self.add_edge(u, v, {**attr, **metadata})
+
     def all_edges_as_iterable(self, include_metadata: bool = False) -> Collection:
         """
         Get a list of all edges in this graph, arbitrary sort.
@@ -287,13 +309,17 @@ class InMemoryCachedBackend(CachedBackend):
 
     _default_uncacheable_methods = [
         "add_node",
+        "add_nodes_from",
         "add_edge",
+        "add_edges_from",
         "ingest_from_edgelist_dataframe",
     ]
 
     _default_write_methods = [
         "add_node",
+        "add_nodes_from",
         "add_edge",
+        "add_edges_from",
         "ingest_from_edgelist_dataframe",
     ]
diff --git a/grand/dialects/__init__.py b/grand/dialects/__init__.py
index 5bfe6e4..29c012a 100644
--- a/grand/dialects/__init__.py
+++ b/grand/dialects/__init__.py
@@ -101,9 +101,15 @@ def __init__(self, parent: "Graph"):
     def add_node(self, name: Hashable, **kwargs):
         return self.parent.backend.add_node(name, kwargs)
 
+    def add_nodes_from(self, nodes_for_adding, **attr):
+        return self.parent.backend.add_nodes_from(nodes_for_adding, **attr)
+
     def add_edge(self, u: Hashable, v: Hashable, **kwargs):
         return self.parent.backend.add_edge(u, v, kwargs)
 
+    def add_edges_from(self, ebunch_to_add, **attr):
+        return self.parent.backend.add_edges_from(ebunch_to_add, **attr)
+
     def remove_node(self, name: Hashable):
         if hasattr(self.parent.backend, "remove_node"):
             return self.parent.backend.remove_node(name)
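
A usage sketch, not part of the patch itself: it assumes grand's documented Graph(backend=...) constructor, the SQLBackend(db_url=..., directed=...) signature, and the .nx NetworkX-style dialect; the SQLite URL and the sample node/edge data are hypothetical and only for illustration. It exercises the batch-insert methods introduced above, which write each batch with a single Connection.execute() call instead of one INSERT per item.

    import grand
    from grand.backends import SQLBackend

    # In-memory SQLite keeps the sketch self-contained; any SQLAlchemy URL works.
    G = grand.Graph(backend=SQLBackend(db_url="sqlite:///:memory:", directed=True))

    # add_nodes_from / add_edges_from expect (node, metadata) and
    # (source, target, metadata) tuples; each call issues one batched INSERT.
    G.nx.add_nodes_from([("A", {"kind": "person"}), ("B", {"kind": "person"})])
    G.nx.add_edges_from([("A", "B", {"weight": 1.0})])

    # Read back through the backend; rows now come straight from the edge table.
    print(list(G.backend.all_edges_as_iterable(include_metadata=True)))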