Skip to content

Commit

Permalink
breakup long fn
Browse files Browse the repository at this point in the history
  • Loading branch information
akhileshh committed May 12, 2024
1 parent 6a2c5da commit fa75e17
Showing 1 changed file with 48 additions and 48 deletions.
96 changes: 48 additions & 48 deletions pychunkedgraph/ingest/create/parent_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,50 @@ def _write_components_helper(args):
_write(cg, layer, pcoords, ccs, node_layer_d, time_stamp)


def _children_rows(
    cg: ChunkedGraph, parent_id, children: Sequence, cx_edges_d: dict, time_stamp
):
    """
    Update children rows to point to the parent_id, collect cached children
    cross chunk edges to lift and update parent cross chunk edges.

    Args:
        cg: graph instance used for chunk-layer lookups, root lookups and
            building row mutations.
        parent_id: id of the new parent node; written to each child's
            ``attributes.Hierarchy.Parent`` column.
        children: ids of the child nodes to update.
        cx_edges_d: mapping ``child_id -> {layer: cross_chunk_edges_array}``
            of cached cross chunk edges (may be missing entries for some children).
        time_stamp: timestamp applied to every row mutation.

    Returns:
        Tuple ``(rows, children_cx_edges)`` — the list of row mutations for
        the children, and a list of per-child ``{layer: remapped_edges}``
        dicts to be lifted into the parent's cross chunk edge columns.
    """
    rows = []
    children_cx_edges = []
    for child in children:
        node_layer = cg.get_chunk_layer(child)
        row_id = serializers.serialize_uint64(child)
        val_dict = {attributes.Hierarchy.Parent: parent_id}
        node_cx_edges_d = cx_edges_d.get(child, {})
        if not node_cx_edges_d:
            # No cached cross chunk edges for this child; only its parent
            # column needs updating.
            rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp))
            continue
        for layer in range(node_layer, cg.meta.layer_count):
            if layer not in node_cx_edges_d:
                continue
            layer_edges = node_cx_edges_d[layer]
            # Lift edge endpoints to their partners' parents at this child's
            # layer so the stored edges reference current-layer nodes.
            nodes = np.unique(layer_edges)
            parents = cg.get_roots(nodes, stop_layer=node_layer, ceil=False)
            edge_parents_d = dict(zip(nodes, parents))
            layer_edges = fastremap.remap(
                layer_edges, edge_parents_d, preserve_missing_labels=True
            )
            # De-duplicate edge rows after remapping collapses endpoints.
            layer_edges = np.unique(layer_edges, axis=0)
            col = attributes.Connectivity.CrossChunkEdge[layer]
            val_dict[col] = layer_edges
            node_cx_edges_d[layer] = layer_edges
        children_cx_edges.append(node_cx_edges_d)
        rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp))
    return rows, children_cx_edges


def _write(
cg: ChunkedGraph,
layer_id,
parent_coords,
components,
node_layer_d,
time_stamp,
ts,
use_threads=True,
):
parent_layers = range(layer_id, cg.meta.layer_count + 1)
Expand All @@ -175,71 +212,34 @@ def _write(
x, y, z = parent_coords
parent_chunk_id = cg.get_chunk_id(layer=layer_id, x=x, y=y, z=z)
parent_chunk_id_dict = cg.get_parent_chunk_id_dict(parent_chunk_id)

for parent_layer in parent_layers:
if len(cc_connections[parent_layer]) == 0:
continue

parent_chunk_id = parent_chunk_id_dict[parent_layer]
reserved_parent_ids = cg.id_client.create_node_ids(
parent_chunk_id,
size=len(cc_connections[parent_layer]),
root_chunk=parent_layer == cg.meta.layer_count and use_threads,
)

for i_cc, node_ids in enumerate(cc_connections[parent_layer]):
parent_id = reserved_parent_ids[i_cc]

for i_cc, children in enumerate(cc_connections[parent_layer]):
parent = reserved_parent_ids[i_cc]
if layer_id == 3:
# when layer 3 is being processed, children chunks are at layer 2
# layer 2 chunks at this time will only have atomic cross edges
cx_edges_d = cg.get_atomic_cross_edges(node_ids)
cx_edges_d = cg.get_atomic_cross_edges(children)
else:
# children are from abstract chunks
cx_edges_d = cg.get_cross_chunk_edges(node_ids, raw_only=True)

children_cx_edges = []
for node in node_ids:
node_layer = cg.get_chunk_layer(node)
row_id = serializers.serialize_uint64(node)
val_dict = {attributes.Hierarchy.Parent: parent_id}

node_cx_edges_d = cx_edges_d.get(node, {})
if not node_cx_edges_d:
rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp))
continue

for layer in range(node_layer, cg.meta.layer_count):
if not layer in node_cx_edges_d:
continue
layer_edges = node_cx_edges_d[layer]
nodes = np.unique(layer_edges)
parents = cg.get_roots(nodes, stop_layer=node_layer, ceil=False)

edge_parents_d = dict(zip(nodes, parents))
layer_edges = fastremap.remap(
layer_edges, edge_parents_d, preserve_missing_labels=True
)
layer_edges = np.unique(layer_edges, axis=0)

col = attributes.Connectivity.CrossChunkEdge[layer]
val_dict[col] = layer_edges
node_cx_edges_d[layer] = layer_edges
children_cx_edges.append(node_cx_edges_d)
rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp))

row_id = serializers.serialize_uint64(parent_id)
val_dict = {attributes.Hierarchy.Child: node_ids}
parent_cx_edges_d = concatenate_cross_edge_dicts(
children_cx_edges, unique=True
)
cx_edges_d = cg.get_cross_chunk_edges(children, raw_only=True)
_rows, cx_edges = _children_rows(cg, parent, children, cx_edges_d, ts)
rows.extend(_rows)
row_id = serializers.serialize_uint64(parent)
val_dict = {attributes.Hierarchy.Child: children}
parent_cx_edges_d = concatenate_cross_edge_dicts(cx_edges, unique=True)
for layer in range(parent_layer, cg.meta.layer_count):
if not layer in parent_cx_edges_d:
continue
col = attributes.Connectivity.CrossChunkEdge[layer]
val_dict[col] = parent_cx_edges_d[layer]

rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp))
rows.append(cg.client.mutate_row(row_id, val_dict, ts))
if len(rows) > 100000:
cg.client.write(rows)
rows = []
Expand Down

0 comments on commit fa75e17

Please sign in to comment.