From 92428dd6f9d9bb7a2760fd930b5878479a84c2cd Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sun, 15 Dec 2024 14:31:53 +0800 Subject: [PATCH] feat(treespec): add method `PyTreeSpec.one_level` (#179) --- CHANGELOG.md | 1 + conda-recipe.yaml | 2 +- docs/source/ops.rst | 4 ++ include/optree/treespec.h | 9 +++ optree/_C.pyi | 2 + optree/__init__.py | 4 ++ optree/ops.py | 112 +++++++++++++++++++++++++++++++------- pyproject.toml | 2 +- src/optree.cpp | 21 +++++-- src/treespec/treespec.cpp | 57 ++++++++++--------- tests/test_treespec.py | 105 +++++++++++++++++++++++++++++++++++ 11 files changed, 267 insertions(+), 52 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index babea005..75fb53d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Add method `PyTreeSpec.one_level` and `PyTreeSpec.is_one_level` by [@XuehaiPan](https://github.com/XuehaiPan) in [#179](https://github.com/metaopt/optree/pull/179). - Add method `PyTreeSpec.transform` by [@XuehaiPan](https://github.com/XuehaiPan) in [#177](https://github.com/metaopt/optree/pull/177). ### Changed diff --git a/conda-recipe.yaml b/conda-recipe.yaml index 17ebd9c3..f3d199dd 100644 --- a/conda-recipe.yaml +++ b/conda-recipe.yaml @@ -39,7 +39,7 @@ dependencies: - pybind11 >= 2.13.1 # Benchmark - - pytorch::pytorch >= 2.0, < 2.4.0a0 + - pytorch::pytorch >= 2.0, < 2.6.0a0 - pytorch::torchvision - pytorch::pytorch-mutex = *=*cpu* - conda-forge::jax >= 0.4.6, < 0.5.0a0 diff --git a/docs/source/ops.rst b/docs/source/ops.rst index 2a7b599b..fd21263b 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -123,9 +123,11 @@ PyTreeSpec Functions treespec_entry treespec_children treespec_child + treespec_one_level treespec_transform treespec_is_leaf treespec_is_strict_leaf + treespec_is_one_level treespec_is_prefix treespec_is_suffix treespec_leaf @@ -146,9 +148,11 @@ PyTreeSpec Functions .. autofunction:: treespec_entry .. autofunction:: treespec_children .. autofunction:: treespec_child +.. autofunction:: treespec_one_level .. autofunction:: treespec_transform .. autofunction:: treespec_is_leaf .. autofunction:: treespec_is_strict_leaf +.. autofunction:: treespec_is_one_level .. autofunction:: treespec_is_prefix .. autofunction:: treespec_is_suffix .. autofunction:: treespec_leaf diff --git a/include/optree/treespec.h b/include/optree/treespec.h index 1ab766d7..17c3738b 100644 --- a/include/optree/treespec.h +++ b/include/optree/treespec.h @@ -161,6 +161,10 @@ class PyTreeSpec { // Return the child at the given index of the PyTreeSpec. [[nodiscard]] std::unique_ptr Child(ssize_t index) const; + // Return the one-level structure of the PyTreeSpec. + [[nodiscard]] std::unique_ptr GetOneLevel( + const std::optional &node = std::nullopt) const; + [[nodiscard]] inline Py_ALWAYS_INLINE ssize_t GetNumLeaves() const { PYTREESPEC_SANITY_CHECK(*this); return m_traversal.back().num_leaves; @@ -195,6 +199,11 @@ class PyTreeSpec { return GetNumNodes() == 1; } + // Test whether this PyTreeSpec represents a one-level tree. + [[nodiscard]] inline Py_ALWAYS_INLINE bool IsOneLevel() const { + return GetNumNodes() == GetNumChildren() + 1 && GetNumLeaves() == GetNumChildren(); + } + // Return true if this PyTreeSpec is a prefix of `other`. [[nodiscard]] bool IsPrefix(const PyTreeSpec &other, const bool &strict = false) const; diff --git a/optree/_C.pyi b/optree/_C.pyi index 1df7f036..87639650 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -129,7 +129,9 @@ class PyTreeSpec: def entry(self, index: int) -> Any: ... def children(self) -> list[Self]: ... def child(self, index: int) -> Self: ... + def one_level(self) -> Self | None: ... def is_leaf(self, *, strict: bool = True) -> bool: ... + def is_one_level(self) -> bool: ... def is_prefix(self, other: Self, *, strict: bool = False) -> bool: ... def is_suffix(self, other: Self, *, strict: bool = False) -> bool: ... def __eq__(self, other: object) -> bool: ... diff --git a/optree/__init__.py b/optree/__init__.py index 7bb07e09..cf77a836 100644 --- a/optree/__init__.py +++ b/optree/__init__.py @@ -80,6 +80,7 @@ treespec_entry, treespec_from_collection, treespec_is_leaf, + treespec_is_one_level, treespec_is_prefix, treespec_is_strict_leaf, treespec_is_suffix, @@ -87,6 +88,7 @@ treespec_list, treespec_namedtuple, treespec_none, + treespec_one_level, treespec_ordereddict, treespec_paths, treespec_structseq, @@ -171,9 +173,11 @@ 'treespec_entry', 'treespec_children', 'treespec_child', + 'treespec_one_level', 'treespec_transform', 'treespec_is_leaf', 'treespec_is_strict_leaf', + 'treespec_is_one_level', 'treespec_is_prefix', 'treespec_is_suffix', 'treespec_leaf', diff --git a/optree/ops.py b/optree/ops.py index f9315afb..891c55e2 100644 --- a/optree/ops.py +++ b/optree/ops.py @@ -94,9 +94,11 @@ 'treespec_entry', 'treespec_children', 'treespec_child', + 'treespec_one_level', 'treespec_transform', 'treespec_is_leaf', 'treespec_is_strict_leaf', + 'treespec_is_one_level', 'treespec_is_prefix', 'treespec_is_suffix', 'treespec_leaf', @@ -903,31 +905,31 @@ def tree_map_with_accessor( {'x': "tree['x'] = 7", 'y': ("tree['y'][0] = 42", "tree['y'][1] = 64")} >>> tree_map_with_accessor(lambda a, x: x + len(a), {'x': 7, 'y': (42, 64), 'z': None}) {'x': 8, 'y': (44, 66), 'z': None} - >>> tree_map_with_accessor( # doctest: +IGNORE_WHITESPACE + >>> tree_map_with_accessor( # doctest: +IGNORE_WHITESPACE,ELLIPSIS ... lambda a, x: a, ... {'x': 7, 'y': (42, 64), 'z': {1.5: None}}, ... ) { - 'x': PyTreeAccessor(*['x'], (MappingEntry(key='x', type=),)), + 'x': PyTreeAccessor(*['x'], ...), 'y': ( - PyTreeAccessor(*['y'][0], (MappingEntry(key='y', type=), SequenceEntry(index=0, type=))), - PyTreeAccessor(*['y'][1], (MappingEntry(key='y', type=), SequenceEntry(index=1, type=))) + PyTreeAccessor(*['y'][0], ...), + PyTreeAccessor(*['y'][1], ...) ), 'z': {1.5: None} } - >>> tree_map_with_accessor( # doctest: +IGNORE_WHITESPACE + >>> tree_map_with_accessor( # doctest: +IGNORE_WHITESPACE,ELLIPSIS ... lambda a, x: a, ... {'x': 7, 'y': (42, 64), 'z': {1.5: None}}, ... none_is_leaf=True, ... ) { - 'x': PyTreeAccessor(*['x'], (MappingEntry(key='x', type=),)), + 'x': PyTreeAccessor(*['x'], ...), 'y': ( - PyTreeAccessor(*['y'][0], (MappingEntry(key='y', type=), SequenceEntry(index=0, type=))), - PyTreeAccessor(*['y'][1], (MappingEntry(key='y', type=), SequenceEntry(index=1, type=))) + PyTreeAccessor(*['y'][0], ...), + PyTreeAccessor(*['y'][1], ...) ), 'z': { - 1.5: PyTreeAccessor(*['z'][1.5], (MappingEntry(key='z', type=), MappingEntry(key=1.5, type=))) + 1.5: PyTreeAccessor(*['z'][1.5], ...) } } @@ -954,7 +956,7 @@ def tree_map_with_accessor( A new pytree with the same structure as ``tree`` but with the value at each leaf given by ``func(a, x, *xs)`` where ``(a, x)`` are the accessor and value at the corresponding leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in ``rests``. - """ # pylint: disable=line-too-long + """ leaves, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace) flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] return treespec.unflatten(map(func, treespec.accessors(), *flat_args)) @@ -1322,7 +1324,7 @@ def tree_transpose_map_with_accessor( 'c': (5, 6) } } - >>> tree_transpose_map_with_accessor( # doctest: +IGNORE_WHITESPACE + >>> tree_transpose_map_with_accessor( # doctest: +IGNORE_WHITESPACE,ELLIPSIS ... lambda a, x: {'path': a.path, 'accessor': a, 'value': x}, ... tree, ... inner_treespec=tree_structure({'path': 0, 'accessor': 0, 'value': 0}), @@ -1335,16 +1337,16 @@ def tree_transpose_map_with_accessor( }, 'accessor': { 'b': ( - PyTreeAccessor(*['b'][0], (MappingEntry(key='b', type=), SequenceEntry(index=0, type=))), + PyTreeAccessor(*['b'][0], ...), [ - PyTreeAccessor(*['b'][1][0], (MappingEntry(key='b', type=), SequenceEntry(index=1, type=), SequenceEntry(index=0, type=))), - PyTreeAccessor(*['b'][1][1], (MappingEntry(key='b', type=), SequenceEntry(index=1, type=), SequenceEntry(index=1, type=))) + PyTreeAccessor(*['b'][1][0], ...), + PyTreeAccessor(*['b'][1][1], ...) ] ), - 'a': PyTreeAccessor(*['a'], (MappingEntry(key='a', type=),)), + 'a': PyTreeAccessor(*['a'], ...), 'c': ( - PyTreeAccessor(*['c'][0], (MappingEntry(key='c', type=), SequenceEntry(index=0, type=))), - PyTreeAccessor(*['c'][1], (MappingEntry(key='c', type=), SequenceEntry(index=1, type=))) + PyTreeAccessor(*['c'][0], ...), + PyTreeAccessor(*['c'][1], ...) ) }, 'value': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)} @@ -2540,6 +2542,12 @@ def treespec_paths(treespec: PyTreeSpec) -> list[tuple[Any, ...]]: """Return a list of paths to the leaves of a treespec. See also :func:`tree_flatten_with_path`, :func:`tree_paths`, and :meth:`PyTreeSpec.paths`. + + >>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)}) + >>> treespec + PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)}) + >>> treespec_paths(treespec) + [('a', 0), ('a', 1, 0), ('a', 1, 1), ('b',), ('c', 0)] """ return treespec.paths() @@ -2547,8 +2555,24 @@ def treespec_paths(treespec: PyTreeSpec) -> list[tuple[Any, ...]]: def treespec_accessors(treespec: PyTreeSpec) -> list[PyTreeAccessor]: """Return a list of accessors to the leaves of a treespec. - See also :func:`tree_flatten_with_accessor`, :func:`tree_accessors` and - :meth:`PyTreeSpec.accessors`. + See also :func:`tree_flatten_with_accessor`, :func:`tree_accessors`, + and :meth:`PyTreeSpec.accessors`. + + >>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)}) + >>> treespec + PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)}) + >>> treespec_accessors(treespec) # doctest: +IGNORE_WHITESPACE,ELLIPSIS + [ + PyTreeAccessor(*['a'][0], ...), + PyTreeAccessor(*['a'][1][0], ...), + PyTreeAccessor(*['a'][1][1], ...), + PyTreeAccessor(*['b'], ...), + PyTreeAccessor(*['c'][0], ...) + ] + >>> treespec_accessors(treespec_leaf()) + [PyTreeAccessor(*, ())] + >>> treespec_accessors(treespec_none()) + [] """ return treespec.accessors() @@ -2558,6 +2582,12 @@ def treespec_entries(treespec: PyTreeSpec) -> list[Any]: See also :func:`treespec_entry`, :func:`treespec_paths`, :func:`treespec_children`, and :meth:`PyTreeSpec.entries`. + + >>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)}) + >>> treespec + PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)}) + >>> treespec_entries(treespec) + ['a', 'b', 'c'] """ return treespec.entries() @@ -2574,7 +2604,13 @@ def treespec_children(treespec: PyTreeSpec) -> list[PyTreeSpec]: """Return a list of treespecs for the children of a treespec. See also :func:`treespec_child`, :func:`treespec_paths`, :func:`treespec_entries`, - and :meth:`PyTreeSpec.children`. + :func:`treespec_one_level`, and :meth:`PyTreeSpec.children`. + + >>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)}) + >>> treespec + PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)}) + >>> treespec_children(treespec) + [PyTreeSpec((*, [*, *])), PyTreeSpec(*), PyTreeSpec((*, None))] """ return treespec.children() @@ -2587,6 +2623,20 @@ def treespec_child(treespec: PyTreeSpec, index: int) -> PyTreeSpec: return treespec.child(index) +def treespec_one_level(treespec: PyTreeSpec) -> PyTreeSpec | None: + """Return the one-level tree structure of the treespec or :data:`None` if the treespec is a leaf. + + See also :func:`treespec_children`, :func:`treespec_is_one_level`, and :meth:`PyTreeSpec.one_level`. + + >>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)}) + >>> treespec + PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)}) + >>> treespec_one_level(treespec) + PyTreeSpec({'a': *, 'b': *, 'c': *}) + """ + return treespec.one_level() + + def treespec_transform( treespec: PyTreeSpec, f_node: Callable[[PyTreeSpec], PyTreeSpec] | None = None, @@ -2713,6 +2763,28 @@ def treespec_is_strict_leaf(treespec: PyTreeSpec) -> bool: return treespec.num_nodes == 1 and treespec.num_leaves == 1 +def treespec_is_one_level(treespec: PyTreeSpec) -> bool: + """Return whether the treespec is a one-level tree structure. + + See also :func:`treespec_is_leaf`, :func:`treespec_one_level`, and :meth:`PyTreeSpec.is_one_level`. + + >>> treespec_is_one_level(tree_structure(1)) + False + >>> treespec_is_one_level(tree_structure((1, 2))) + True + >>> treespec_is_one_level(tree_structure({'a': 1, 'b': 2, 'c': 3})) + True + >>> treespec_is_one_level(tree_structure({'a': 1, 'b': (2, 3), 'c': 4})) + False + >>> treespec_is_one_level(tree_structure(None)) + True + """ + return ( + treespec.num_nodes == treespec.num_children + 1 + and treespec.num_leaves == treespec.num_children + ) + + def treespec_is_prefix( treespec: PyTreeSpec, other_treespec: PyTreeSpec, diff --git a/pyproject.toml b/pyproject.toml index 404ed359..5cd86cca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,7 @@ docs = [ ] benchmark = [ "jax[cpu] >= 0.4.6, < 0.5.0a0", - "torch >= 2.0, < 2.4.0a0", + "torch >= 2.0, < 2.6.0a0", "torchvision", "dm-tree >= 0.1, < 0.2.0a0", "pandas", diff --git a/src/optree.cpp b/src/optree.cpp index 67c34432..baab205d 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -228,6 +228,16 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] &PyTreeSpec::Child, "Return the treespec for the child at the given index.", py::arg("index")) + .def( + "one_level", + [](const PyTreeSpec& t) -> std::optional> { + if (t.IsLeaf()) [[unlikely]] { + return std::nullopt; + } + return t.GetOneLevel(); + }, + "Return the one-level structure of the root node. Return None if the root node " + "represents a leaf.") .def_property_readonly("num_leaves", &PyTreeSpec::GetNumLeaves, "Number of leaves in the tree.") @@ -237,7 +247,7 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] "Note that a leaf is also a node but has no children.") .def_property_readonly("num_children", &PyTreeSpec::GetNumChildren, - "Number of children in the current node. " + "Number of children of the root node. " "Note that a leaf is also a node but has no children.") .def_property_readonly( "none_is_leaf", @@ -252,13 +262,16 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] .def_property_readonly( "type", [](const PyTreeSpec& t) -> py::object { return t.GetType(); }, - "The type of the current node. Return None if the current node is a leaf.") - .def_property_readonly("kind", &PyTreeSpec::GetPyTreeKind, "The kind of the current node.") + "The type of the root node. Return None if the root node is a leaf.") + .def_property_readonly("kind", &PyTreeSpec::GetPyTreeKind, "The kind of the root node.") .def("is_leaf", &PyTreeSpec::IsLeaf, - "Test whether the current node is a leaf.", + "Test whether the treespec represents a leaf.", py::kw_only(), py::arg("strict") = true) + .def("is_one_level", + &PyTreeSpec::IsOneLevel, + "Test whether the treespec represents a one-level tree.") .def("is_prefix", &PyTreeSpec::IsPrefix, "Test whether this treespec is a prefix of the given treespec.", diff --git a/src/treespec/treespec.cpp b/src/treespec/treespec.cpp index a022f1d0..74bde9d9 100644 --- a/src/treespec/treespec.cpp +++ b/src/treespec/treespec.cpp @@ -458,29 +458,9 @@ std::unique_ptr PyTreeSpec::Transform(const std::optional(*this); } - const auto create_nodespec = [this](const Node& node) -> std::unique_ptr { - auto nodespec = std::make_unique(); - for (ssize_t i = 0; i < node.arity; ++i) { - nodespec->m_traversal.emplace_back(Node{ - .kind = PyTreeKind::Leaf, - .arity = 0, - .num_leaves = 1, - .num_nodes = 1, - }); - } - auto& root = nodespec->m_traversal.emplace_back(node); - root.num_leaves = (node.kind == PyTreeKind::Leaf ? 1 : node.arity); - root.num_nodes = node.arity + 1; - nodespec->m_none_is_leaf = m_none_is_leaf; - nodespec->m_namespace = m_namespace; - nodespec->m_traversal.shrink_to_fit(); - PYTREESPEC_SANITY_CHECK(*nodespec); - return nodespec; - }; - const auto transform = - [&create_nodespec, &f_node, &f_leaf](const Node& node) -> std::unique_ptr { - auto nodespec = create_nodespec(node); + [this, &f_node, &f_leaf](const Node& node) -> std::unique_ptr { + auto nodespec = GetOneLevel(node); const auto& func = (node.kind == PyTreeKind::Leaf ? f_leaf : f_node); if (!func) [[likely]] { @@ -491,7 +471,7 @@ std::unique_ptr PyTreeSpec::Transform(const std::optional(out)) [[unlikely]] { std::ostringstream oss{}; oss << "Expected the PyTreeSpec transform function returns a PyTreeSpec, got " - << PyRepr(out) << " (input: " << create_nodespec(node)->ToString() << ")."; + << PyRepr(out) << " (input: " << GetOneLevel(node)->ToString() << ")."; throw py::type_error(oss.str()); } return std::make_unique(thread_safe_cast(out)); @@ -510,7 +490,7 @@ std::unique_ptr PyTreeSpec::Transform(const std::optionalToString() - << " (input: " << create_nodespec(node)->ToString() << ")."; + << " (input: " << GetOneLevel(node)->ToString() << ")."; throw py::value_error(oss.str()); } if (!transformed->m_namespace.empty()) [[unlikely]] { @@ -531,14 +511,14 @@ std::unique_ptr PyTreeSpec::Transform(const std::optionalToString() - << " (input: " << create_nodespec(node)->ToString() << ")."; + << " (input: " << GetOneLevel(node)->ToString() << ")."; throw py::value_error(oss.str()); } if (transformed->GetNumNodes() != node.arity + 1) [[unlikely]] { std::ostringstream oss{}; oss << "Expected the PyTreeSpec transform function returns an one-level PyTreeSpec " "as the input, got " - << transformed->ToString() << " (input: " << create_nodespec(node)->ToString() + << transformed->ToString() << " (input: " << GetOneLevel(node)->ToString() << ")."; throw py::value_error(oss.str()); } @@ -1012,4 +992,29 @@ py::object PyTreeSpec::GetType(const std::optional& node) const { } } +std::unique_ptr PyTreeSpec::GetOneLevel(const std::optional& node) const { + if (!node) [[likely]] { + PYTREESPEC_SANITY_CHECK(*this); + } + + const Node& n = node.value_or(m_traversal.back()); + auto out = std::make_unique(); + for (ssize_t i = 0; i < n.arity; ++i) { + out->m_traversal.emplace_back(Node{ + .kind = PyTreeKind::Leaf, + .arity = 0, + .num_leaves = 1, + .num_nodes = 1, + }); + } + auto& root = out->m_traversal.emplace_back(n); + root.num_leaves = (n.kind == PyTreeKind::Leaf ? 1 : n.arity); + root.num_nodes = n.arity + 1; + out->m_none_is_leaf = m_none_is_leaf; + out->m_namespace = m_namespace; + out->m_traversal.shrink_to_fit(); + PYTREESPEC_SANITY_CHECK(*out); + return out; +} + } // namespace optree diff --git a/tests/test_treespec.py b/tests/test_treespec.py index 10018a87..487dd9bd 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -923,6 +923,67 @@ def test_treespec_child( ] +@parametrize( + tree=TREES, + none_is_leaf=[False, True], + namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], +) +def test_treespec_one_level( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + treespec = optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace) + if treespec.type is None: + assert treespec.is_leaf() + assert optree.treespec_one_level(treespec) is None + assert optree.treespec_children(treespec) == [] + assert treespec.num_children == 0 + else: + one_level = optree.treespec_one_level(treespec) + counter = itertools.count() + expected_treespec = optree.tree_structure( + tree, + is_leaf=lambda x: next(counter) > 0, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + num_children = treespec.num_children + assert not treespec.is_leaf() + assert not one_level.is_leaf() + assert not expected_treespec.is_leaf() + assert one_level == expected_treespec + assert optree.treespec_one_level(one_level) == one_level + assert optree.treespec_one_level(expected_treespec) == expected_treespec + assert one_level.num_nodes == num_children + 1 + assert one_level.num_leaves == num_children + assert one_level.num_children == num_children + assert len(one_level) == num_children + assert optree.treespec_entries(one_level) == optree.treespec_entries(treespec) + assert all(optree.treespec_child(one_level, i).is_leaf() for i in range(num_children)) + assert all(child.is_leaf() for child in optree.treespec_children(one_level)) + assert optree.treespec_is_prefix(one_level, treespec) + assert optree.treespec_is_suffix(treespec, one_level) + assert ( + optree.treespec_from_collection( + optree.tree_unflatten(one_level, treespec.children()), + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + == treespec + ) + it = iter(treespec.children()) + assert optree.treespec_transform(one_level, None, lambda _: next(it)) == treespec + + def test_treespec_transform(): treespec = optree.tree_structure(((1, 2, 3), (4,))) assert optree.treespec_transform(treespec) == treespec @@ -1130,6 +1191,50 @@ def test_treespec_is_leaf(): assert optree.tree_structure([]).is_leaf(strict=False) +@parametrize( + tree=TREES, + none_is_leaf=[False, True], + namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], +) +def test_treespec_is_one_level( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + treespec = optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace) + if treespec.type is None: + assert treespec.is_leaf() + assert optree.treespec_one_level(treespec) is None + assert not optree.treespec_is_one_level(treespec) + else: + one_level = optree.treespec_one_level(treespec) + counter = itertools.count() + expected_treespec = optree.tree_structure( + tree, + is_leaf=lambda x: next(counter) > 0, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + assert not treespec.is_leaf() + assert not one_level.is_leaf() + assert not expected_treespec.is_leaf() + assert one_level == expected_treespec + assert optree.treespec_one_level(one_level) == one_level + assert optree.treespec_one_level(expected_treespec) == expected_treespec + assert optree.treespec_is_one_level(one_level) + assert optree.treespec_is_one_level(expected_treespec) + assert optree.treespec_is_one_level(treespec) == (treespec == one_level) + assert optree.treespec_is_one_level(treespec) == (treespec == expected_treespec) + + @parametrize( namespace=['', 'undefined', 'namespace'], )