Skip to content

Commit

Permalink
test: add more tests for PyTreeSpec.one_level
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Dec 14, 2024
1 parent 5f626f9 commit 672fe13
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ PyTreeSpec Functions
treespec_entry
treespec_children
treespec_child
treespec_one_level
treespec_transform
treespec_is_leaf
treespec_is_strict_leaf
Expand All @@ -146,6 +147,7 @@ PyTreeSpec Functions
.. autofunction:: treespec_entry
.. autofunction:: treespec_children
.. autofunction:: treespec_child
.. autofunction:: treespec_one_level
.. autofunction:: treespec_transform
.. autofunction:: treespec_is_leaf
.. autofunction:: treespec_is_strict_leaf
Expand Down
2 changes: 2 additions & 0 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
treespec_list,
treespec_namedtuple,
treespec_none,
treespec_one_level,
treespec_ordereddict,
treespec_paths,
treespec_structseq,
Expand Down Expand Up @@ -171,6 +172,7 @@
'treespec_entry',
'treespec_children',
'treespec_child',
'treespec_one_level',
'treespec_transform',
'treespec_is_leaf',
'treespec_is_strict_leaf',
Expand Down
35 changes: 34 additions & 1 deletion optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
'treespec_entry',
'treespec_children',
'treespec_child',
'treespec_one_level',
'treespec_transform',
'treespec_is_leaf',
'treespec_is_strict_leaf',
Expand Down Expand Up @@ -2540,6 +2541,12 @@ def treespec_paths(treespec: PyTreeSpec) -> list[tuple[Any, ...]]:
"""Return a list of paths to the leaves of a treespec.
See also :func:`tree_flatten_with_path`, :func:`tree_paths`, and :meth:`PyTreeSpec.paths`.
>>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
>>> treespec
PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
>>> treespec_paths(treespec)
[('a', 0), ('a', 1, 0), ('a', 1, 1), ('b',), ('c', 0)]
"""
return treespec.paths()

Expand Down Expand Up @@ -2570,6 +2577,12 @@ def treespec_entries(treespec: PyTreeSpec) -> list[Any]:
See also :func:`treespec_entry`, :func:`treespec_paths`, :func:`treespec_children`,
and :meth:`PyTreeSpec.entries`.
>>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
>>> treespec
PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
>>> treespec_entries(treespec)
['a', 'b', 'c']
"""
return treespec.entries()

Expand All @@ -2586,7 +2599,13 @@ def treespec_children(treespec: PyTreeSpec) -> list[PyTreeSpec]:
"""Return a list of treespecs for the children of a treespec.
See also :func:`treespec_child`, :func:`treespec_paths`, :func:`treespec_entries`,
and :meth:`PyTreeSpec.children`.
:func:`treespec_one_level`, and :meth:`PyTreeSpec.children`.
>>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
>>> treespec
PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
>>> treespec_children(treespec)
[PyTreeSpec((*, [*, *])), PyTreeSpec(*), PyTreeSpec((*, None))]
"""
return treespec.children()

Expand All @@ -2599,6 +2618,20 @@ def treespec_child(treespec: PyTreeSpec, index: int) -> PyTreeSpec:
return treespec.child(index)


def treespec_one_level(treespec: PyTreeSpec) -> PyTreeSpec | None:
"""Return the one-level tree structure of the treespec or :data:`None` if the treespec is a leaf.
See also :func:`treespec_children` and :meth:`PyTreeSpec.one_level`.
>>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
>>> treespec
PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
>>> treespec_one_level(treespec)
PyTreeSpec({'a': *, 'b': *, 'c': *})
"""
return treespec.one_level()


def treespec_transform(
treespec: PyTreeSpec,
f_node: Callable[[PyTreeSpec], PyTreeSpec] | None = None,
Expand Down
51 changes: 51 additions & 0 deletions tests/test_treespec.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,57 @@ def test_treespec_child(
]


@parametrize(
tree=TREES,
none_is_leaf=[False, True],
namespace=['', 'undefined', 'namespace'],
dict_should_be_sorted=[False, True],
dict_session_namespace=['', 'undefined', 'namespace'],
)
def test_treespec_one_level(
tree,
none_is_leaf,
namespace,
dict_should_be_sorted,
dict_session_namespace,
):
with optree.dict_insertion_ordered(
not dict_should_be_sorted,
namespace=dict_session_namespace or GLOBAL_NAMESPACE,
):
treespec = optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace)
if treespec.type is None:
assert treespec.is_leaf()
assert optree.treespec_one_level(treespec) is None
assert optree.treespec_children(treespec) == []
assert treespec.num_children == 0
else:
one_level = optree.treespec_one_level(treespec)
num_children = treespec.num_children
assert not treespec.is_leaf()
assert not one_level.is_leaf()
assert one_level.one_level() == one_level
assert one_level.num_nodes == num_children + 1
assert one_level.num_leaves == num_children
assert one_level.num_children == num_children
assert len(one_level) == num_children
assert optree.treespec_entries(one_level) == optree.treespec_entries(treespec)
assert all(optree.treespec_child(one_level, i).is_leaf() for i in range(num_children))
assert all(child.is_leaf() for child in optree.treespec_children(one_level))
assert optree.treespec_is_prefix(one_level, treespec)
assert optree.treespec_is_suffix(treespec, one_level)
assert (
optree.treespec_from_collection(
optree.tree_unflatten(one_level, treespec.children()),
none_is_leaf=none_is_leaf,
namespace=namespace,
)
== treespec
)
it = iter(treespec.children())
assert optree.treespec_transform(one_level, None, lambda _: next(it)) == treespec


def test_treespec_transform():
treespec = optree.tree_structure(((1, 2, 3), (4,)))
assert optree.treespec_transform(treespec) == treespec
Expand Down

0 comments on commit 672fe13

Please sign in to comment.