diff --git a/docs/source/ops.rst b/docs/source/ops.rst index 2a7b599b..ec8c1c6c 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -123,6 +123,7 @@ PyTreeSpec Functions treespec_entry treespec_children treespec_child + treespec_one_level treespec_transform treespec_is_leaf treespec_is_strict_leaf @@ -146,6 +147,7 @@ PyTreeSpec Functions .. autofunction:: treespec_entry .. autofunction:: treespec_children .. autofunction:: treespec_child +.. autofunction:: treespec_one_level .. autofunction:: treespec_transform .. autofunction:: treespec_is_leaf .. autofunction:: treespec_is_strict_leaf diff --git a/optree/__init__.py b/optree/__init__.py index 7bb07e09..12d985e7 100644 --- a/optree/__init__.py +++ b/optree/__init__.py @@ -87,6 +87,7 @@ treespec_list, treespec_namedtuple, treespec_none, + treespec_one_level, treespec_ordereddict, treespec_paths, treespec_structseq, @@ -171,6 +172,7 @@ 'treespec_entry', 'treespec_children', 'treespec_child', + 'treespec_one_level', 'treespec_transform', 'treespec_is_leaf', 'treespec_is_strict_leaf', diff --git a/optree/ops.py b/optree/ops.py index 9b615d75..c3be21ee 100644 --- a/optree/ops.py +++ b/optree/ops.py @@ -94,6 +94,7 @@ 'treespec_entry', 'treespec_children', 'treespec_child', + 'treespec_one_level', 'treespec_transform', 'treespec_is_leaf', 'treespec_is_strict_leaf', @@ -2540,6 +2541,12 @@ def treespec_paths(treespec: PyTreeSpec) -> list[tuple[Any, ...]]: """Return a list of paths to the leaves of a treespec. See also :func:`tree_flatten_with_path`, :func:`tree_paths`, and :meth:`PyTreeSpec.paths`. + + >>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)}) + >>> treespec + PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)}) + >>> treespec_paths(treespec) + [('a', 0), ('a', 1, 0), ('a', 1, 1), ('b',), ('c', 0)] """ return treespec.paths() @@ -2570,6 +2577,12 @@ def treespec_entries(treespec: PyTreeSpec) -> list[Any]: See also :func:`treespec_entry`, :func:`treespec_paths`, :func:`treespec_children`, and :meth:`PyTreeSpec.entries`. + + >>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)}) + >>> treespec + PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)}) + >>> treespec_entries(treespec) + ['a', 'b', 'c'] """ return treespec.entries() @@ -2586,7 +2599,13 @@ def treespec_children(treespec: PyTreeSpec) -> list[PyTreeSpec]: """Return a list of treespecs for the children of a treespec. See also :func:`treespec_child`, :func:`treespec_paths`, :func:`treespec_entries`, - and :meth:`PyTreeSpec.children`. + :func:`treespec_one_level`, and :meth:`PyTreeSpec.children`. + + >>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)}) + >>> treespec + PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)}) + >>> treespec_children(treespec) + [PyTreeSpec((*, [*, *])), PyTreeSpec(*), PyTreeSpec((*, None))] """ return treespec.children() @@ -2599,6 +2618,20 @@ def treespec_child(treespec: PyTreeSpec, index: int) -> PyTreeSpec: return treespec.child(index) +def treespec_one_level(treespec: PyTreeSpec) -> PyTreeSpec | None: + """Return the one-level tree structure of the treespec or :data:`None` if the treespec is a leaf. + + See also :func:`treespec_children` and :meth:`PyTreeSpec.one_level`. + + >>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)}) + >>> treespec + PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)}) + >>> treespec_one_level(treespec) + PyTreeSpec({'a': *, 'b': *, 'c': *}) + """ + return treespec.one_level() + + def treespec_transform( treespec: PyTreeSpec, f_node: Callable[[PyTreeSpec], PyTreeSpec] | None = None, diff --git a/tests/test_treespec.py b/tests/test_treespec.py index 10018a87..095d223b 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -923,6 +923,57 @@ def test_treespec_child( ] +@parametrize( + tree=TREES, + none_is_leaf=[False, True], + namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], +) +def test_treespec_one_level( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + treespec = optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace) + if treespec.type is None: + assert treespec.is_leaf() + assert optree.treespec_one_level(treespec) is None + assert optree.treespec_children(treespec) == [] + assert treespec.num_children == 0 + else: + one_level = optree.treespec_one_level(treespec) + num_children = treespec.num_children + assert not treespec.is_leaf() + assert not one_level.is_leaf() + assert one_level.one_level() == one_level + assert one_level.num_nodes == num_children + 1 + assert one_level.num_leaves == num_children + assert one_level.num_children == num_children + assert len(one_level) == num_children + assert optree.treespec_entries(one_level) == optree.treespec_entries(treespec) + assert all(optree.treespec_child(one_level, i).is_leaf() for i in range(num_children)) + assert all(child.is_leaf() for child in optree.treespec_children(one_level)) + assert optree.treespec_is_prefix(one_level, treespec) + assert optree.treespec_is_suffix(treespec, one_level) + assert ( + optree.treespec_from_collection( + optree.tree_unflatten(one_level, treespec.children()), + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + == treespec + ) + it = iter(treespec.children()) + assert optree.treespec_transform(one_level, None, lambda _: next(it)) == treespec + + def test_treespec_transform(): treespec = optree.tree_structure(((1, 2, 3), (4,))) assert optree.treespec_transform(treespec) == treespec