From fa75e17c083403c82b2abad9da027e9c463a739e Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Sun, 12 May 2024 18:18:05 +0000 Subject: [PATCH] breakup long fn --- pychunkedgraph/ingest/create/parent_layer.py | 96 ++++++++++---------- 1 file changed, 48 insertions(+), 48 deletions(-) diff --git a/pychunkedgraph/ingest/create/parent_layer.py b/pychunkedgraph/ingest/create/parent_layer.py index 09be61407..a777d9efc 100644 --- a/pychunkedgraph/ingest/create/parent_layer.py +++ b/pychunkedgraph/ingest/create/parent_layer.py @@ -154,13 +154,50 @@ def _write_components_helper(args): _write(cg, layer, pcoords, ccs, node_layer_d, time_stamp) +def _children_rows( + cg: ChunkedGraph, parent_id, children: Sequence, cx_edges_d: dict, time_stamp +): + """ + Update children rows to point to the parent_id, collect cached children + cross chunk edges to lift and update parent cross chunk edges. + Returns list of mutations to children and list of children cross edges. + """ + rows = [] + children_cx_edges = [] + for child in children: + node_layer = cg.get_chunk_layer(child) + row_id = serializers.serialize_uint64(child) + val_dict = {attributes.Hierarchy.Parent: parent_id} + node_cx_edges_d = cx_edges_d.get(child, {}) + if not node_cx_edges_d: + rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp)) + continue + for layer in range(node_layer, cg.meta.layer_count): + if not layer in node_cx_edges_d: + continue + layer_edges = node_cx_edges_d[layer] + nodes = np.unique(layer_edges) + parents = cg.get_roots(nodes, stop_layer=node_layer, ceil=False) + edge_parents_d = dict(zip(nodes, parents)) + layer_edges = fastremap.remap( + layer_edges, edge_parents_d, preserve_missing_labels=True + ) + layer_edges = np.unique(layer_edges, axis=0) + col = attributes.Connectivity.CrossChunkEdge[layer] + val_dict[col] = layer_edges + node_cx_edges_d[layer] = layer_edges + children_cx_edges.append(node_cx_edges_d) + rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp)) + return rows, children_cx_edges + + def _write( cg: ChunkedGraph, layer_id, parent_coords, components, node_layer_d, - time_stamp, + ts, use_threads=True, ): parent_layers = range(layer_id, cg.meta.layer_count + 1) @@ -175,71 +212,34 @@ def _write( x, y, z = parent_coords parent_chunk_id = cg.get_chunk_id(layer=layer_id, x=x, y=y, z=z) parent_chunk_id_dict = cg.get_parent_chunk_id_dict(parent_chunk_id) - for parent_layer in parent_layers: if len(cc_connections[parent_layer]) == 0: continue - parent_chunk_id = parent_chunk_id_dict[parent_layer] reserved_parent_ids = cg.id_client.create_node_ids( parent_chunk_id, size=len(cc_connections[parent_layer]), root_chunk=parent_layer == cg.meta.layer_count and use_threads, ) - - for i_cc, node_ids in enumerate(cc_connections[parent_layer]): - parent_id = reserved_parent_ids[i_cc] - + for i_cc, children in enumerate(cc_connections[parent_layer]): + parent = reserved_parent_ids[i_cc] if layer_id == 3: # when layer 3 is being processed, children chunks are at layer 2 # layer 2 chunks at this time will only have atomic cross edges - cx_edges_d = cg.get_atomic_cross_edges(node_ids) + cx_edges_d = cg.get_atomic_cross_edges(children) else: - # children are from abstract chunks - cx_edges_d = cg.get_cross_chunk_edges(node_ids, raw_only=True) - - children_cx_edges = [] - for node in node_ids: - node_layer = cg.get_chunk_layer(node) - row_id = serializers.serialize_uint64(node) - val_dict = {attributes.Hierarchy.Parent: parent_id} - - node_cx_edges_d = cx_edges_d.get(node, {}) - if not node_cx_edges_d: - rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp)) - continue - - for layer in range(node_layer, cg.meta.layer_count): - if not layer in node_cx_edges_d: - continue - layer_edges = node_cx_edges_d[layer] - nodes = np.unique(layer_edges) - parents = cg.get_roots(nodes, stop_layer=node_layer, ceil=False) - - edge_parents_d = dict(zip(nodes, parents)) - layer_edges = fastremap.remap( - layer_edges, edge_parents_d, preserve_missing_labels=True - ) - layer_edges = np.unique(layer_edges, axis=0) - - col = attributes.Connectivity.CrossChunkEdge[layer] - val_dict[col] = layer_edges - node_cx_edges_d[layer] = layer_edges - children_cx_edges.append(node_cx_edges_d) - rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp)) - - row_id = serializers.serialize_uint64(parent_id) - val_dict = {attributes.Hierarchy.Child: node_ids} - parent_cx_edges_d = concatenate_cross_edge_dicts( - children_cx_edges, unique=True - ) + cx_edges_d = cg.get_cross_chunk_edges(children, raw_only=True) + _rows, cx_edges = _children_rows(cg, parent, children, cx_edges_d, ts) + rows.extend(_rows) + row_id = serializers.serialize_uint64(parent) + val_dict = {attributes.Hierarchy.Child: children} + parent_cx_edges_d = concatenate_cross_edge_dicts(cx_edges, unique=True) for layer in range(parent_layer, cg.meta.layer_count): if not layer in parent_cx_edges_d: continue col = attributes.Connectivity.CrossChunkEdge[layer] val_dict[col] = parent_cx_edges_d[layer] - - rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp)) + rows.append(cg.client.mutate_row(row_id, val_dict, ts)) if len(rows) > 100000: cg.client.write(rows) rows = []