Skip to content

Commit

Permalink
add level2_graph parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
ceesem committed Apr 7, 2024
1 parent d8a30e1 commit 3c67675
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 22 deletions.
2 changes: 1 addition & 1 deletion pcg_skel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from . import utils
from . import chunk_tools

__version__ = "1.0.0"
__version__ = "1.0.1"
70 changes: 49 additions & 21 deletions pcg_skel/pcg_skel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import cloudvolume
import warnings
import numpy as np
import datetime

from caveclient import CAVEclient
from meshparty import meshwork, skeletonize, trimesh_io
Expand All @@ -26,6 +27,7 @@ def pcg_graph(
return_l2dict: bool = False,
nan_rounds: int = 10,
require_complete: bool = False,
level2_graph: Optional[np.ndarray] = None,
):
"""Compute the level 2 spatial graph (or mesh) of a given root id using the l2cache.
Expand All @@ -46,6 +48,10 @@ def pcg_graph(
If vertices are missing (or not computed), this sets the number of iterations for smoothing over them.
require_complete : bool
If True, raise an Exception if any vertices are missing from the cache.
level2_graph : np.ndarray, optional
Level 2 graph for the root id as returned by client.chunkedgraph.level2_chunk_graph.
A list of lists of edges between level 2 chunks, as defined by their chunk ids.
If None, will query the chunkedgraph for the level 2 graph. Optional, by default None.
Returns
Expand All @@ -60,7 +66,11 @@ def pcg_graph(
if cv is None:
cv = client.info.segmentation_cloudvolume(progress=False)

lvl2_eg = client.chunkedgraph.level2_chunk_graph(root_id)
if level2_graph is None:
lvl2_eg = client.chunkedgraph.level2_chunk_graph(root_id)
else:
lvl2_eg = level2_graph

eg, l2dict_mesh, l2dict_r_mesh, x_ch = chunk_tools.build_spatial_graph(
lvl2_eg,
cv,
Expand Down Expand Up @@ -112,6 +122,11 @@ def pcg_skeleton_direct(
Distance (in nanometers) for soma collapse.
root_id : int, optional
Root id of the segment, used in metadata. Optional, by default None.
level2_graph : np.ndarray, optional
Level 2 graph for the root id as returned by client.chunkedgraph.level2_chunk_graph.
A list of lists of edges between level 2 chunks, as defined by their chunk ids.
If None, will query the chunkedgraph for the level 2 graph. Optional, by default None.
Returns
-------
Expand Down Expand Up @@ -158,6 +173,7 @@ def pcg_skeleton(
collapse_radius: Numeric = 7500,
nan_rounds: int = 10,
require_complete: bool = False,
level2_graph: Optional[np.ndarray] = None,
):
"""Produce a skeleton from the level 2 graph.
Parameters
Expand Down Expand Up @@ -192,6 +208,10 @@ def pcg_skeleton(
If vertices are missing (or not computed), this sets the number of iterations for smoothing over them.
require_complete : bool, optional
If True, raise an Exception if any vertices are missing from the cache.
level2_graph : np.ndarray, optional
Level 2 graph for the root id as returned by client.chunkedgraph.level2_chunk_graph.
A list of lists of edges between level 2 chunks, as defined by their chunk ids.
If None, will query the chunkedgraph for the level 2 graph. Optional, by default None.
Returns
-------
Expand Down Expand Up @@ -220,6 +240,7 @@ def pcg_skeleton(
return_l2dict=True,
nan_rounds=nan_rounds,
require_complete=require_complete,
level2_graph=level2_graph,
)

metameta = {"space": "l2cache", "datastack": client.datastack_name}
Expand Down Expand Up @@ -256,26 +277,27 @@ def pcg_skeleton(


def pcg_meshwork(
root_id,
datastack_name=None,
client=None,
cv=None,
root_point=None,
root_point_resolution=None,
collapse_soma=False,
collapse_radius=DEFAULT_COLLAPSE_RADIUS,
synapses=None,
synapse_table=None,
remove_self_synapse=True,
live_query=False,
timestamp=None,
invalidation_d=DEFAULT_INVALIDATION_D,
require_complete=False,
metadata=False,
synapse_partners=False,
synapse_point_resolution=[1, 1, 1],
synapse_representative_point_pre="ctr_pt_position",
synapse_representative_point_post="ctr_pt_position",
root_id: int,
datastack_name: Optional[str] = None,
client: Optional[CAVEclient] = None,
cv: Optional[cloudvolume.CloudVolume] = None,
root_point: Optional[list] = None,
root_point_resolution: Optional[list] = None,
collapse_soma: bool = False,
collapse_radius: Numeric = DEFAULT_COLLAPSE_RADIUS,
synapses: Optional[Union[bool, str]] = None,
synapse_table: Optional[str] = None,
remove_self_synapse: bool = True,
live_query: bool = False,
timestamp: Optional[datetime.datetime] = None,
invalidation_d: Numeric = DEFAULT_INVALIDATION_D,
require_complete: bool = False,
metadata: bool = False,
synapse_partners: bool = False,
synapse_point_resolution: list = [1, 1, 1],
synapse_representative_point_pre: str = "ctr_pt_position",
synapse_representative_point_post: str = "ctr_pt_position",
level2_graph: Optional[np.ndarray] = None,
) -> meshwork.Meshwork:
"""Generate a meshwork file based on the level 2 graph.
Expand Down Expand Up @@ -350,6 +372,7 @@ def pcg_meshwork(
return_mesh=True,
return_l2dict_mesh=True,
require_complete=require_complete,
level2_graph=level2_graph,
)

nrn = meshwork.Meshwork(mesh, seg_id=root_id, skeleton=sk)
Expand Down Expand Up @@ -581,6 +604,11 @@ def coord_space_meshwork(
Invalidation radius in hops for the mesh skeletonization along the chunk adjacency graph, by default 3
require_complete : bool, optional
If True, raise an Exception if any vertices are missing from the cache, by default False
level2_graph : np.ndarray, optional
Level 2 graph for the root id as returned by client.chunkedgraph.level2_chunk_graph.
A list of lists of edges between level 2 chunks, as defined by their chunk ids.
If None, will query the chunkedgraph for the level 2 graph. Optional, by default None.
Returns
-------
Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ def mip_resolution(self):
test_post_synapses = pd.read_feather(base_path / "data/post_syn.feather")


@pytest.fixture()
def test_l2eg():
return test_l2graph


@pytest.fixture()
def test_client(mocker):
client = CAVEclient(TEST_DATASTACK, info_cache=INFO_CACHE)
Expand Down
14 changes: 14 additions & 0 deletions tests/test_pcg_skel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@ def test_pcg_skeleton(test_client, root_id, center_pt):
assert graph_sk.path_length() > 0


def test_pcg_skeleton_prebaked(test_client, root_id, center_pt, test_l2eg):
graph_sk = pcg_skel.pcg_skeleton(
root_id,
test_client,
collapse_radius=True,
root_point=center_pt,
root_point_resolution=[4, 4, 40],
level2_graph=test_l2eg,
)
assert graph_sk.vertices is not None
assert graph_sk.edges is not None
assert graph_sk.path_length() > 0


def test_pcg_skeleton_direct(test_client, root_id, center_pt):
graph_m = pcg_skel.pcg_graph(root_id, test_client)
sk = pcg_skel.pcg_skeleton_direct(
Expand Down

0 comments on commit 3c67675

Please sign in to comment.