From 8ad38f5e89875d1cdb41f67e49332f242f7dd59c Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 6 Dec 2024 15:18:36 +0800 Subject: [PATCH 1/7] feat(treespec): add method `PyTreeSpec.transform` --- docs/source/ops.rst | 2 + include/optree/treespec.h | 6 ++ optree/_C.pyi | 5 ++ optree/__init__.py | 2 + optree/ops.py | 25 ++++++ src/optree.cpp | 6 ++ src/treespec/constructor.cpp | 28 +++---- src/treespec/treespec.cpp | 142 +++++++++++++++++++++++++++++++++++ 8 files changed, 202 insertions(+), 14 deletions(-) diff --git a/docs/source/ops.rst b/docs/source/ops.rst index 759555e8..2a7b599b 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -123,6 +123,7 @@ PyTreeSpec Functions treespec_entry treespec_children treespec_child + treespec_transform treespec_is_leaf treespec_is_strict_leaf treespec_is_prefix @@ -145,6 +146,7 @@ PyTreeSpec Functions .. autofunction:: treespec_entry .. autofunction:: treespec_children .. autofunction:: treespec_child +.. autofunction:: treespec_transform .. autofunction:: treespec_is_leaf .. autofunction:: treespec_is_strict_leaf .. autofunction:: treespec_is_prefix diff --git a/include/optree/treespec.h b/include/optree/treespec.h index ef10adbc..60407bd2 100644 --- a/include/optree/treespec.h +++ b/include/optree/treespec.h @@ -128,6 +128,12 @@ class PyTreeSpec { [[nodiscard]] std::unique_ptr BroadcastToCommonSuffix( const PyTreeSpec &other) const; + // Transform a PyTreeSpec by applying `f_node(nodespec)` to nodes and `f_leaf(leafspec)` to + // leaves. + [[nodiscard]] std::unique_ptr Transform( + const std::optional &f_node, + const std::optional &f_leaf) const; + // Compose two PyTreeSpecs, replacing the leaves of this tree with copies of `inner`. 
[[nodiscard]] std::unique_ptr Compose(const PyTreeSpec &inner_treespec) const; diff --git a/optree/_C.pyi b/optree/_C.pyi index 2628d5b1..32d1e7c9 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -111,6 +111,11 @@ class PyTreeSpec: def unflatten(self, leaves: Iterable[T]) -> PyTree[T]: ... def flatten_up_to(self, full_tree: PyTree[T]) -> list[PyTree[T]]: ... def broadcast_to_common_suffix(self, other: Self) -> Self: ... + def transform( + self, + f_node: Callable[[Self], Self] | None = None, + f_leaf: Callable[[Self], Self] | None = None, + ) -> Self: ... def compose(self, inner_treespec: Self) -> Self: ... def walk( self, diff --git a/optree/__init__.py b/optree/__init__.py index 87941736..7bb07e09 100644 --- a/optree/__init__.py +++ b/optree/__init__.py @@ -90,6 +90,7 @@ treespec_ordereddict, treespec_paths, treespec_structseq, + treespec_transform, treespec_tuple, ) from optree.registry import ( @@ -170,6 +171,7 @@ 'treespec_entry', 'treespec_children', 'treespec_child', + 'treespec_transform', 'treespec_is_leaf', 'treespec_is_strict_leaf', 'treespec_is_prefix', diff --git a/optree/ops.py b/optree/ops.py index 1df1c0d4..065bb2e6 100644 --- a/optree/ops.py +++ b/optree/ops.py @@ -94,6 +94,7 @@ 'treespec_entry', 'treespec_children', 'treespec_child', + 'treespec_transform', 'treespec_is_leaf', 'treespec_is_strict_leaf', 'treespec_is_prefix', @@ -2586,6 +2587,30 @@ def treespec_child(treespec: PyTreeSpec, index: int) -> PyTreeSpec: return treespec.child(index) +def treespec_transform( + treespec: PyTreeSpec, + f_node: Callable[[PyTreeSpec], PyTreeSpec] | None = None, + f_leaf: Callable[[PyTreeSpec], PyTreeSpec] | None = None, +) -> PyTreeSpec: + """Transform a treespec by applying functions to its nodes and leaves. + + See also :func:`treespec_children`, :func:`treespec_is_leaf`, and :meth:`PyTreeSpec.transform`. 
+ + >>> treespec = tree_structure({'a': (0, [1, 2]), 'b': 3, 'c': (4, None)}) + >>> treespec + PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)}) + >>> treespec_transform(treespec, lambda spec: treespec_dict(zip(spec.entries(), spec.children()))) + PyTreeSpec({'a': {0: *, 1: {0: *, 1: *}}, 'b': *, 'c': {0: *, 1: {}}}) + >>> treespec_transform(treespec, lambda spec: treespec_tuple(spec.children())) + PyTreeSpec(((*, (*, *)), *, (*, ()))) + >>> treespec_transform(treespec, lambda spec: treespec_list(spec.children()) if spec.type is tuple else spec) + PyTreeSpec({'a': [*, [*, *]], 'b': *, 'c': [*, None]}) + >>> treespec_transform(treespec, None, lambda spec: tree_structure((1, [2]))) + PyTreeSpec({'a': ((*, [*]), [(*, [*]), (*, [*])]), 'b': (*, [*]), 'c': ((*, [*]), None)}) + """ + return treespec.transform(f_node, f_leaf) + + def treespec_is_leaf(treespec: PyTreeSpec, *, strict: bool = True) -> bool: """Return whether the treespec is a leaf that has no children. diff --git a/src/optree.cpp b/src/optree.cpp index 7fe4e90b..756d6973 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -200,6 +200,12 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] &PyTreeSpec::BroadcastToCommonSuffix, "Broadcast to the common suffix of this treespec and other treespec.", py::arg("other")) + .def("transform", + &PyTreeSpec::Transform, + "Transform the pytree structure by applying ``f_node(nodespec)`` at nodes and " + "``f_leaf(leafspec)`` at leaves.", + py::arg("f_node") = std::nullopt, + py::arg("f_leaf") = std::nullopt) .def("compose", &PyTreeSpec::Compose, "Compose two treespecs. 
Constructs the inner treespec as a subtree at each leaf node.", diff --git a/src/treespec/constructor.cpp b/src/treespec/constructor.cpp index 903c417a..a023e257 100644 --- a/src/treespec/constructor.cpp +++ b/src/treespec/constructor.cpp @@ -72,9 +72,9 @@ template Node node; node.kind = PyTreeTypeRegistry::GetKind(handle, node.custom, registry_namespace); - const auto verify_children = [&handle, &node](const std::vector& children, - std::vector& treespecs, - std::string& register_namespace) -> void { + const auto verify_children = [&handle, &node, &registry_namespace]( + const std::vector& children, + std::vector& treespecs) -> void { for (const py::object& child : children) { if (!py::isinstance(child)) [[unlikely]] { std::ostringstream oss{}; @@ -106,16 +106,16 @@ template } } if (!common_registry_namespace.empty()) [[likely]] { - if (register_namespace.empty()) [[likely]] { - register_namespace = common_registry_namespace; - } else if (register_namespace != common_registry_namespace) [[unlikely]] { + if (registry_namespace.empty()) [[likely]] { + registry_namespace = common_registry_namespace; + } else if (registry_namespace != common_registry_namespace) [[unlikely]] { std::ostringstream oss{}; - oss << "Expected treespec(s) with namespace " << PyRepr(register_namespace) + oss << "Expected treespec(s) with namespace " << PyRepr(registry_namespace) << ", got " << PyRepr(common_registry_namespace) << "."; throw py::value_error(oss.str()); } } else if (node.kind != PyTreeKind::Custom) [[likely]] { - register_namespace = ""; + registry_namespace = ""; } }; @@ -143,7 +143,7 @@ template for (ssize_t i = 0; i < node.arity; ++i) { children.emplace_back(TupleGetItem(handle, i)); } - verify_children(children, treespecs, registry_namespace); + verify_children(children, treespecs); break; } @@ -155,7 +155,7 @@ template children.emplace_back(ListGetItem(handle, i)); } } - verify_children(children, treespecs, registry_namespace); + verify_children(children, treespecs); break; } 
@@ -178,7 +178,7 @@ template children.emplace_back(DictGetItem(dict, key)); } } - verify_children(children, treespecs, registry_namespace); + verify_children(children, treespecs); if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] { const scoped_critical_section cs{handle}; node.node_data = py::make_tuple(py::getattr(handle, Py_Get_ID(default_factory)), @@ -197,7 +197,7 @@ template for (ssize_t i = 0; i < node.arity; ++i) { children.emplace_back(TupleGetItem(tuple, i)); } - verify_children(children, treespecs, registry_namespace); + verify_children(children, treespecs); break; } @@ -209,7 +209,7 @@ template for (ssize_t i = 0; i < node.arity; ++i) { children.emplace_back(ListGetItem(list, i)); } - verify_children(children, treespecs, registry_namespace); + verify_children(children, treespecs); break; } @@ -235,7 +235,7 @@ template children.emplace_back(py::reinterpret_borrow(child)); } } - verify_children(children, treespecs, registry_namespace); + verify_children(children, treespecs); if (num_out == 3) [[likely]] { const py::object node_entries = TupleGetItem(out, 2); if (!node_entries.is_none()) [[likely]] { diff --git a/src/treespec/treespec.cpp b/src/treespec/treespec.cpp index d3313d85..bc19de59 100644 --- a/src/treespec/treespec.cpp +++ b/src/treespec/treespec.cpp @@ -20,6 +20,7 @@ limitations under the License. 
#include // std::unique_ptr, std::make_unique #include // std::optional #include // std::ostringstream +#include // std::string #include // std::tuple #include // std::move #include // std::vector @@ -448,6 +449,147 @@ std::unique_ptr PyTreeSpec::BroadcastToCommonSuffix(const PyTreeSpec return treespec; } +// NOLINTNEXTLINE[readability-function-cognitive-complexity] +std::unique_ptr PyTreeSpec::Transform(const std::optional& f_node, + const std::optional& f_leaf) const { + PYTREESPEC_SANITY_CHECK(*this); + + if (!f_node && !f_leaf) [[unlikely]] { + return std::make_unique(*this); + } + + const auto create_nodespec = [this](const Node& node) -> std::unique_ptr { + auto nodespec = std::make_unique(); + for (ssize_t i = 0; i < node.arity; ++i) { + nodespec->m_traversal.emplace_back(Node{ + .kind = PyTreeKind::Leaf, + .arity = 0, + .num_leaves = 1, + .num_nodes = 1, + }); + } + auto& root = nodespec->m_traversal.emplace_back(node); + root.num_leaves = (node.kind == PyTreeKind::Leaf ? 1 : node.arity); + root.num_nodes = node.arity + 1; + nodespec->m_none_is_leaf = m_none_is_leaf; + nodespec->m_namespace = m_namespace; + nodespec->m_traversal.shrink_to_fit(); + PYTREESPEC_SANITY_CHECK(*nodespec); + return nodespec; + }; + + const auto transform = + [&create_nodespec, &f_node, &f_leaf](const Node& node) -> std::unique_ptr { + auto nodespec = create_nodespec(node); + + const auto& func = (node.kind == PyTreeKind::Leaf ? 
f_leaf : f_node); + if (!func) [[likely]] { + return nodespec; + } + + const py::object out = EVALUATE_WITH_LOCK_HELD((*func)(std::move(nodespec)), *func); + if (!py::isinstance(out)) [[unlikely]] { + std::ostringstream oss{}; + oss << "Expected the PyTreeSpec transform function returns a PyTreeSpec, got " + << PyRepr(out) << " (input: " << create_nodespec(node)->ToString() << ")."; + throw py::type_error(oss.str()); + } + return std::make_unique(thread_safe_cast(out)); + }; + + auto treespec = std::make_unique(); + std::string common_registry_namespace = m_namespace; + ssize_t num_extra_leaves = 0; + ssize_t num_extra_nodes = 0; + auto pending_num_leaves_nodes = reserved_vector>(4); + for (const Node& node : m_traversal) { + auto transformed = transform(node); + if (transformed->m_none_is_leaf != m_none_is_leaf) [[unlikely]] { + std::ostringstream oss{}; + oss << "Expected the PyTreeSpec transform function returns " + "a PyTreeSpec with the same value of " + << (m_none_is_leaf ? "`none_is_leaf=True`" : "`none_is_leaf=False`") + << " as the input, got " << transformed->ToString() + << " (input: " << create_nodespec(node)->ToString() << ")."; + throw py::value_error(oss.str()); + } + if (!transformed->m_namespace.empty()) [[unlikely]] { + if (common_registry_namespace.empty()) [[likely]] { + common_registry_namespace = transformed->m_namespace; + } else if (transformed->m_namespace != common_registry_namespace) [[unlikely]] { + std::ostringstream oss{}; + oss << "Expected the PyTreeSpec transform function returns " + "a PyTreeSpec with namespace " + << PyRepr(common_registry_namespace) << ", got " + << PyRepr(transformed->m_namespace) << "."; + throw py::value_error(oss.str()); + } + } + if (node.kind != PyTreeKind::Leaf) [[likely]] { + if (transformed->GetNumLeaves() != node.arity) [[unlikely]] { + std::ostringstream oss{}; + oss << "Expected the PyTreeSpec transform function returns " + "a PyTreeSpec with the same number of arity as the input (" + << node.arity << 
"), got " << transformed->ToString() + << " (input: " << create_nodespec(node)->ToString() << ")."; + throw py::value_error(oss.str()); + } + if (transformed->GetNumNodes() != node.arity + 1) [[unlikely]] { + std::ostringstream oss{}; + oss << "Expected the PyTreeSpec transform function returns an one-level PyTreeSpec " + "as the input, got " + << transformed->ToString() << " (input: " << create_nodespec(node)->ToString() + << ")."; + throw py::value_error(oss.str()); + } + auto& subroot = treespec->m_traversal.emplace_back(transformed->m_traversal.back()); + EXPECT_GE(py::ssize_t_cast(pending_num_leaves_nodes.size()), + node.arity, + "PyTreeSpec::Transform() walked off start of array."); + subroot.num_leaves = 0; + subroot.num_nodes = 1; + for (ssize_t i = 0; i < node.arity; ++i) { + const auto& [num_leaves, num_nodes] = pending_num_leaves_nodes.back(); + pending_num_leaves_nodes.pop_back(); + subroot.num_leaves += num_leaves; + subroot.num_nodes += num_nodes; + } + pending_num_leaves_nodes.emplace_back(subroot.num_leaves, subroot.num_nodes); + } else [[unlikely]] { + std::copy(transformed->m_traversal.cbegin(), + transformed->m_traversal.cend(), + std::back_inserter(treespec->m_traversal)); + const ssize_t num_leaves = transformed->GetNumLeaves(); + const ssize_t num_nodes = transformed->GetNumNodes(); + num_extra_leaves += num_leaves - 1; + num_extra_nodes += num_nodes - 1; + pending_num_leaves_nodes.emplace_back(num_leaves, num_nodes); + } + } + EXPECT_EQ(pending_num_leaves_nodes.size(), + 1, + "PyTreeSpec::Transform() did not yield a singleton."); + + const auto& root = treespec->m_traversal.back(); + EXPECT_EQ(root.num_leaves, + GetNumLeaves() + num_extra_leaves, + "Number of transformed tree leaves mismatch."); + EXPECT_EQ(root.num_nodes, + GetNumNodes() + num_extra_nodes, + "Number of transformed tree nodes mismatch."); + EXPECT_EQ(root.num_leaves, + treespec->GetNumLeaves(), + "Number of transformed tree leaves mismatch."); + EXPECT_EQ(root.num_nodes, + 
treespec->GetNumNodes(), + "Number of transformed tree nodes mismatch."); + treespec->m_none_is_leaf = m_none_is_leaf; + treespec->m_namespace = m_namespace; + treespec->m_traversal.shrink_to_fit(); + PYTREESPEC_SANITY_CHECK(*treespec); + return treespec; +} + std::unique_ptr PyTreeSpec::Compose(const PyTreeSpec& inner_treespec) const { PYTREESPEC_SANITY_CHECK(*this); PYTREESPEC_SANITY_CHECK(inner_treespec); From 85e82caeee45f86d451a2f599f44740746197a97 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 6 Dec 2024 17:33:55 +0800 Subject: [PATCH 2/7] test: add tests for `PyTreeSpec.transform` --- tests/test_treespec.py | 85 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 75 insertions(+), 10 deletions(-) diff --git a/tests/test_treespec.py b/tests/test_treespec.py index 6dd7787d..d9eee7e2 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -591,24 +591,36 @@ def test_treespec_compose_children( namespace=namespace, ) composed_treespec = treespec.compose(inner_treespec) + transformed_treespec = treespec.transform(None, lambda _: inner_treespec) expected_leaves = treespec.num_leaves * inner_treespec.num_leaves assert composed_treespec.num_leaves == treespec.num_leaves * inner_treespec.num_leaves + assert transformed_treespec.num_leaves == expected_leaves expected_nodes = (treespec.num_nodes - treespec.num_leaves) + ( inner_treespec.num_nodes * treespec.num_leaves ) assert composed_treespec.num_nodes == expected_nodes + assert transformed_treespec.num_nodes == expected_nodes leaves = list(range(expected_leaves)) composed = optree.tree_unflatten(composed_treespec, leaves) - assert leaves == optree.tree_leaves( + transformed = optree.tree_unflatten(transformed_treespec, leaves) + assert composed == transformed + + if 'FlatCache' in str(treespec): + return + + assert (leaves, composed_treespec) == optree.tree_flatten( composed, none_is_leaf=none_is_leaf, namespace=namespace, ) - - if 'FlatCache' in str(treespec): - return + assert (leaves, 
transformed_treespec) == optree.tree_flatten( + transformed, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) assert composed_treespec == expected_treespec + assert transformed_treespec == expected_treespec stack = [(composed_treespec.children(), expected_treespec.children())] while stack: @@ -617,12 +629,6 @@ def test_treespec_compose_children( assert composed_child == expected_child stack.append((composed_child.children(), expected_child.children())) - assert composed_treespec == optree.tree_structure( - composed, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - if treespec == expected_treespec: assert not (treespec != expected_treespec) assert not (treespec < expected_treespec) @@ -917,6 +923,65 @@ def test_treespec_child( ] +def test_treespec_transform(): + treespec = optree.tree_structure(((1, 2, 3), (4,))) + assert optree.treespec_transform(treespec) == treespec + assert optree.treespec_transform(treespec) is not treespec + assert optree.treespec_transform( + treespec, + None, + lambda _: optree.tree_structure((1, [2])), + ) == optree.tree_structure((((0, [1]), (2, [3]), (4, [5])), ((6, [7]),))) + assert optree.treespec_transform( + treespec, + lambda spec: optree.treespec_list(spec.children()), + ) == optree.tree_structure([[1, 2, 3], [4]]) + assert optree.treespec_transform( + treespec, + lambda spec: optree.treespec_dict(zip('abcd', spec.children())), + ) == optree.tree_structure({'a': {'a': 0, 'b': 1, 'c': 2}, 'b': {'a': 3}}) + assert optree.treespec_transform( + treespec, + lambda spec: optree.treespec_dict(zip('abcd', spec.children())), + lambda spec: optree.tree_structure([0, None, 1]), + ) == optree.tree_structure( + {'a': {'a': [0, None, 1], 'b': [2, None, 3], 'c': [4, None, 5]}, 'b': {'a': [6, None, 7]}}, + ) + + with pytest.raises( + TypeError, + match=re.escape('Expected the PyTreeSpec transform function returns a PyTreeSpec'), + ): + optree.treespec_transform(treespec, lambda _: None) + with pytest.raises( + TypeError, + 
match=re.escape('Expected the PyTreeSpec transform function returns a PyTreeSpec'), + ): + optree.treespec_transform(treespec, None, lambda _: None) + with pytest.raises( + ValueError, + match=( + r'Expected the PyTreeSpec transform function returns a PyTreeSpec ' + r'with the same value of `none_is_leaf=\w+` as the input' + ), + ): + optree.treespec_transform( + treespec, + lambda spec: optree.treespec_list( + [optree.treespec_leaf(none_is_leaf=True)] * spec.num_children, + none_is_leaf=True, + ), + ) + with pytest.raises( + ValueError, + match=( + r'Expected the PyTreeSpec transform function returns a PyTreeSpec ' + r'with the same number of arity as the input' + ), + ): + optree.treespec_transform(treespec, lambda _: optree.tree_structure([0, 1])) + + @parametrize( tree=TREES, none_is_leaf=[False, True], From 5d0b4a4e24618a02e650087f1874a9a588ab06dd Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 7 Dec 2024 16:41:14 +0800 Subject: [PATCH 3/7] test: add more tests for `PyTreeSpec.transform` --- src/treespec/treespec.cpp | 2 +- tests/test_treespec.py | 46 ++++++++++++++++++++++++++++++++++----- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/src/treespec/treespec.cpp b/src/treespec/treespec.cpp index bc19de59..7aa5d5bd 100644 --- a/src/treespec/treespec.cpp +++ b/src/treespec/treespec.cpp @@ -584,7 +584,7 @@ std::unique_ptr PyTreeSpec::Transform(const std::optionalGetNumNodes(), "Number of transformed tree nodes mismatch."); treespec->m_none_is_leaf = m_none_is_leaf; - treespec->m_namespace = m_namespace; + treespec->m_namespace = common_registry_namespace; treespec->m_traversal.shrink_to_fit(); PYTREESPEC_SANITY_CHECK(*treespec); return treespec; diff --git a/tests/test_treespec.py b/tests/test_treespec.py index d9eee7e2..10018a87 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -947,6 +947,24 @@ def test_treespec_transform(): ) == optree.tree_structure( {'a': {'a': [0, None, 1], 'b': [2, None, 3], 'c': [4, None, 5]}, 'b': 
{'a': [6, None, 7]}}, ) + namespaced_treespec = optree.tree_structure( + MyAnotherDict({1: MyAnotherDict({2: 1, 1: 2, 0: 3}), 0: MyAnotherDict({0: 4})}), + namespace='namespace', + ) + assert ( + optree.treespec_transform( + treespec, + lambda spec: optree.tree_structure( + MyAnotherDict(zip(spec.entries(), spec.children())), + namespace='namespace', + ), + ) + == namespaced_treespec + ) + assert optree.treespec_transform( + namespaced_treespec, + lambda spec: optree.treespec_list(spec.children()), + ) == optree.tree_structure([[1, 2, 3], [4]]) with pytest.raises( TypeError, @@ -961,8 +979,8 @@ def test_treespec_transform(): with pytest.raises( ValueError, match=( - r'Expected the PyTreeSpec transform function returns a PyTreeSpec ' - r'with the same value of `none_is_leaf=\w+` as the input' + r'Expected the PyTreeSpec transform function returns ' + r'a PyTreeSpec with the same value of `none_is_leaf=\w+` as the input' ), ): optree.treespec_transform( @@ -972,14 +990,32 @@ def test_treespec_transform(): none_is_leaf=True, ), ) + with pytest.raises(ValueError, match=r'Expected treespec\(s\) with namespace .*, got .*\.'): + + def fn(spec): + with optree.dict_insertion_ordered(True, namespace='undefined'): + return optree.treespec_dict(zip('abcd', spec.children()), namespace='undefined') + + optree.treespec_transform(namespaced_treespec, fn) with pytest.raises( ValueError, - match=( - r'Expected the PyTreeSpec transform function returns a PyTreeSpec ' - r'with the same number of arity as the input' + match=re.escape( + 'Expected the PyTreeSpec transform function returns ' + 'a PyTreeSpec with the same number of arity as the input', ), ): optree.treespec_transform(treespec, lambda _: optree.tree_structure([0, 1])) + with pytest.raises( + ValueError, + match=re.escape( + 'Expected the PyTreeSpec transform function returns ' + 'an one-level PyTreeSpec as the input', + ), + ): + optree.treespec_transform( + treespec, + lambda spec: optree.tree_structure([None] + [0] * 
spec.num_children), + ) @parametrize( From db7b462f54b8b75663d8824eaac97c13b25de9d3 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 10 Dec 2024 00:00:36 +0800 Subject: [PATCH 4/7] docs(CHANGELOG): update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f06d29c..babea005 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Add method `PyTreeSpec.transform` by [@XuehaiPan](https://github.com/XuehaiPan) in [#177](https://github.com/metaopt/optree/pull/177). ### Changed From 393e53d9888142175f4fa7c683f820bf8ecd9520 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 10 Dec 2024 15:39:42 +0800 Subject: [PATCH 5/7] test: add more examples in doctest --- include/optree/treespec.h | 4 ++-- optree/ops.py | 29 +++++++++++++++++++++++++++-- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/include/optree/treespec.h b/include/optree/treespec.h index 60407bd2..9b2e7ec4 100644 --- a/include/optree/treespec.h +++ b/include/optree/treespec.h @@ -131,8 +131,8 @@ class PyTreeSpec { // Transform a PyTreeSpec by applying `f_node(nodespec)` to nodes and `f_leaf(leafspec)` to // leaves. [[nodiscard]] std::unique_ptr Transform( - const std::optional &f_node, - const std::optional &f_leaf) const; + const std::optional &f_node = std::nullopt, + const std::optional &f_leaf = std::nullopt) const; // Compose two PyTreeSpecs, replacing the leaves of this tree with copies of `inner`. [[nodiscard]] std::unique_ptr Compose(const PyTreeSpec &inner_treespec) const; diff --git a/optree/ops.py b/optree/ops.py index 065bb2e6..f9315afb 100644 --- a/optree/ops.py +++ b/optree/ops.py @@ -2596,14 +2596,39 @@ def treespec_transform( See also :func:`treespec_children`, :func:`treespec_is_leaf`, and :meth:`PyTreeSpec.transform`. 
- >>> treespec = tree_structure({'a': (0, [1, 2]), 'b': 3, 'c': (4, None)}) + >>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)}) >>> treespec PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)}) >>> treespec_transform(treespec, lambda spec: treespec_dict(zip(spec.entries(), spec.children()))) PyTreeSpec({'a': {0: *, 1: {0: *, 1: *}}, 'b': *, 'c': {0: *, 1: {}}}) + >>> treespec_transform( + ... treespec, + ... lambda spec: ( + ... treespec_ordereddict(zip(spec.entries(), spec.children())) + ... if spec.type is dict + ... else spec + ... ), + ... ) + PyTreeSpec(OrderedDict({'a': (*, [*, *]), 'b': *, 'c': (*, None)})) + >>> treespec_transform( + ... treespec, + ... lambda spec: ( + ... treespec_ordereddict(tree_unflatten(spec, spec.children())) + ... if spec.type is dict + ... else spec + ... ), + ... ) + PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': (*, None)})) >>> treespec_transform(treespec, lambda spec: treespec_tuple(spec.children())) PyTreeSpec(((*, (*, *)), *, (*, ()))) - >>> treespec_transform(treespec, lambda spec: treespec_list(spec.children()) if spec.type is tuple else spec) + >>> treespec_transform( + ... treespec, + ... lambda spec: ( + ... treespec_list(spec.children()) + ... if spec.type is tuple + ... else spec + ... ), + ... 
) PyTreeSpec({'a': [*, [*, *]], 'b': *, 'c': [*, None]}) >>> treespec_transform(treespec, None, lambda spec: tree_structure((1, [2]))) PyTreeSpec({'a': ((*, [*]), [(*, [*]), (*, [*])]), 'b': (*, [*]), 'c': ((*, [*]), None)}) From 02e79158e73bf7d2294186b4eb3c9a30a62c6973 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 10 Dec 2024 18:33:05 +0800 Subject: [PATCH 6/7] chore(workflows): install build dependencies --- .github/workflows/tests-with-pydebug.yml | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/.github/workflows/tests-with-pydebug.yml b/.github/workflows/tests-with-pydebug.yml index a5654c75..9e4a8ca5 100644 --- a/.github/workflows/tests-with-pydebug.yml +++ b/.github/workflows/tests-with-pydebug.yml @@ -77,6 +77,21 @@ jobs: git clone https://github.com/pyenv/pyenv.git "${PYENV_ROOT}" echo "PYENV_ROOT=${PYENV_ROOT}" >> "${GITHUB_ENV}" echo "PATH=${PATH}" >> "${GITHUB_ENV}" + if [[ "${{ runner.os }}" == 'Linux' ]]; then + sudo apt-get update -qq && sudo apt-get install -yqq --no-install-recommends \ + make \ + build-essential \ + libssl-dev \ + zlib1g-dev \ + libbz2-dev \ + libsqlite3-dev \ + libncurses-dev \ + libreadline-dev \ + libgdbm-dev \ + liblzma-dev + elif [[ "${{ runner.os }}" == 'macOS' ]]; then + brew install --only-dependencies python@3 + fi - name: Set up pyenv id: setup-pyenv-windows From 8e227516c79b3f50e996d9dc73f988bbcb27c035 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 10 Dec 2024 22:29:50 +0800 Subject: [PATCH 7/7] refactor(treespec): refactor `PyTreeSpec.walk` --- include/optree/exceptions.h | 2 +- include/optree/treespec.h | 8 ++--- optree/_C.pyi | 4 +-- src/optree.cpp | 12 +++---- src/treespec/traversal.cpp | 46 +++++++++++++++++-------- src/treespec/treespec.cpp | 2 +- tests/test_ops.py | 69 +++++++++++++++++++++++++++---------- 7 files changed, 97 insertions(+), 46 deletions(-) diff --git a/include/optree/exceptions.h b/include/optree/exceptions.h index 59672f0d..9ed3ee17 100644 --- 
a/include/optree/exceptions.h +++ b/include/optree/exceptions.h @@ -48,7 +48,7 @@ class InternalError : public std::logic_error { : InternalError([&message, &file, &lineno, &function]() -> std::string { std::ostringstream oss{}; oss << message << " ("; - if (function.has_value()) [[likely]] { + if (function) [[likely]] { oss << "function `" << *function << "` "; } oss << "at file " << file << ":" << lineno << ")\n\n" diff --git a/include/optree/treespec.h b/include/optree/treespec.h index 9b2e7ec4..1ab766d7 100644 --- a/include/optree/treespec.h +++ b/include/optree/treespec.h @@ -138,10 +138,10 @@ class PyTreeSpec { [[nodiscard]] std::unique_ptr Compose(const PyTreeSpec &inner_treespec) const; // Map a function over a PyTree structure, applying f_leaf to each leaf, and - // f_node(children, node_data) to each container node. - [[nodiscard]] py::object Walk(const py::function &f_node, - const std::optional &f_leaf, - const py::iterable &leaves) const; + // f_node(node_type, node_data, children) to each container node. + [[nodiscard]] py::object Walk(const py::iterable &leaves, + const std::optional &f_node = std::nullopt, + const std::optional &f_leaf = std::nullopt) const; // Return paths to all leaves in the PyTreeSpec. [[nodiscard]] std::vector Paths() const; diff --git a/optree/_C.pyi b/optree/_C.pyi index 32d1e7c9..1df7f036 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -119,9 +119,9 @@ class PyTreeSpec: def compose(self, inner_treespec: Self) -> Self: ... def walk( self, - f_node: Callable[[tuple[U, ...], MetaData], U], - f_leaf: Callable[[T], U] | None, leaves: Iterable[T], + f_node: Callable[[builtins.type, MetaData, tuple[U, ...]], U] | None = None, + f_leaf: Callable[[T], U] | None = None, ) -> U: ... def paths(self) -> list[tuple[Any, ...]]: ... def accessors(self) -> list[PyTreeAccessor]: ... 
diff --git a/src/optree.cpp b/src/optree.cpp index 756d6973..67c34432 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -31,7 +31,7 @@ py::module_ GetCxxModule(const std::optional& module) { PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store storage; return storage .call_once_and_store_result([&module]() -> py::module_ { - EXPECT_TRUE(module.has_value(), "The module must be provided."); + EXPECT_TRUE(module, "The module must be provided."); return *module; }) .get_stored(); @@ -212,11 +212,11 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] py::arg("inner_treespec")) .def("walk", &PyTreeSpec::Walk, - "Walk over the pytree structure, calling ``f_node(children, node_data)`` at nodes, " - "and ``f_leaf(leaf)`` at leaves.", - py::arg("f_node"), - py::arg("f_leaf"), - py::arg("leaves")) + "Walk over the pytree structure, calling ``f_node(node_type, node_data, children)`` " + "at nodes, and ``f_leaf(leaf)`` at leaves.", + py::arg("leaves"), + py::arg("f_node") = std::nullopt, + py::arg("f_leaf") = std::nullopt) .def("paths", &PyTreeSpec::Paths, "Return a list of paths to the leaves of the treespec.") .def("accessors", &PyTreeSpec::Accessors, diff --git a/src/treespec/traversal.cpp b/src/treespec/traversal.cpp index 95dc0873..ba23b849 100644 --- a/src/treespec/traversal.cpp +++ b/src/treespec/traversal.cpp @@ -18,6 +18,7 @@ limitations under the License. 
#include <optional> // std::optional #include <sstream> // std::ostringstream #include <stdexcept> // std::runtime_error +#include <utility> // std::move #include "optree/optree.h" @@ -171,9 +172,10 @@ py::object PyTreeIter::Next() { } } -py::object PyTreeSpec::Walk(const py::function& f_node, - const std::optional<py::function>& f_leaf, - const py::iterable& leaves) const { +// NOLINTNEXTLINE[readability-function-cognitive-complexity] +py::object PyTreeSpec::Walk(const py::iterable& leaves, + const std::optional<py::function>& f_node, + const std::optional<py::function>& f_leaf) const { PYTREESPEC_SANITY_CHECK(*this); const scoped_critical_section cs{leaves}; @@ -203,18 +205,34 @@ py::object PyTreeSpec::Walk(const py::function& f_node, case PyTreeKind::Deque: case PyTreeKind::StructSequence: case PyTreeKind::Custom: { - EXPECT_GE(py::ssize_t_cast(agenda.size()), - node.arity, - "Too few elements for custom type."); - const py::tuple tuple{node.arity}; - for (ssize_t i = node.arity - 1; i >= 0; --i) { - TupleSetItem(tuple, i, agenda.back()); - agenda.pop_back(); + const ssize_t size = py::ssize_t_cast(agenda.size()); + EXPECT_GE(size, node.arity, "Too few elements for custom type."); + + if (f_node) [[likely]] { + const py::tuple children{node.arity}; + for (ssize_t i = node.arity - 1; i >= 0; --i) { + TupleSetItem(children, i, agenda.back()); + agenda.pop_back(); + } + + const py::object& node_type = GetType(node); + { + const scoped_critical_section cs2{node_type}; + agenda.emplace_back(EVALUATE_WITH_LOCK_HELD2( + (*f_node)(node_type, + (node.node_data ? node.node_data : py::none()), + children), + node.node_data, + *f_node)); + } + } else [[unlikely]] { + py::object out = + MakeNode(node, + (node.arity > 0 ? &agenda[size - node.arity] : nullptr), + node.arity); + agenda.resize(size - node.arity); + agenda.emplace_back(std::move(out)); } - agenda.emplace_back(EVALUATE_WITH_LOCK_HELD2( - f_node(tuple, (node.node_data ?
node.node_data : py::none())), - node.node_data, - f_node)); break; } diff --git a/src/treespec/treespec.cpp b/src/treespec/treespec.cpp index 7aa5d5bd..a022f1d0 100644 --- a/src/treespec/treespec.cpp +++ b/src/treespec/treespec.cpp @@ -979,7 +979,7 @@ std::unique_ptr<PyTreeSpec> PyTreeSpec::Child(ssize_t index) const { } py::object PyTreeSpec::GetType(const std::optional<Node>& node) const { - if (!node.has_value()) [[likely]] { + if (!node) [[likely]] { PYTREESPEC_SANITY_CHECK(*this); } diff --git a/tests/test_ops.py b/tests/test_ops.py index 48a7f6b1..5622a8e2 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -247,34 +247,46 @@ def test_walk(): # X # + def unflatten_node(node_type, node_data, children): + return optree.register_pytree_node.get(node_type).unflatten_func(node_data, children) + def get_functions(): - nodes_visited = [] + node_types_visited = [] node_data_visited = [] + nodes_visited = [] leaves_visited = [] - def f_node(node, node_data): - nodes_visited.append(node) + def f_node(node_type, node_data, node): + node_types_visited.append(node_type) node_data_visited.append(node_data) + nodes_visited.append(node) return copy.deepcopy(nodes_visited), None def f_leaf(leaf): leaves_visited.append(leaf) return copy.deepcopy(leaves_visited) - return f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited + return f_node, f_leaf, leaves_visited, nodes_visited, node_data_visited, node_types_visited leaves, treespec = optree.tree_flatten(tree) - f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited = get_functions() + f_node, f_leaf, *_ = get_functions() with pytest.raises(ValueError, match='Too few leaves for PyTreeSpec.'): - treespec.walk(f_node, f_leaf, leaves[:-1]) + treespec.walk(leaves[:-1], f_node, f_leaf) - f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited = get_functions() + f_node, f_leaf, *_ = get_functions() with pytest.raises(ValueError, match='Too many leaves for PyTreeSpec.'): - treespec.walk(f_node, f_leaf, (*leaves,
0)) - - f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited = get_functions() - output = treespec.walk(f_node, f_leaf, leaves) + treespec.walk((*leaves, 0), f_node, f_leaf) + + ( + f_node, + f_leaf, + leaves_visited, + nodes_visited, + node_data_visited, + node_types_visited, + ) = get_functions() + output = treespec.walk(leaves, f_node, f_leaf) assert leaves_visited == [1, 2, 3, 4] assert nodes_visited == [ (), @@ -282,6 +294,7 @@ def f_leaf(leaf): ([1], [1, 2], ([(), ([1, 2, 3], ([()], None), [1, 2, 3, 4])], None)), ] assert node_data_visited == [None, ['e', 'f', 'g'], ['a', 'b', 'c']] + assert node_types_visited == [type(None), dict, dict] assert output == ( [ (), @@ -291,24 +304,36 @@ def f_leaf(leaf): None, ) + assert treespec.walk(leaves) == tree + assert treespec.walk(leaves, unflatten_node, None) == tree + assert treespec.walk(leaves, None, lambda x: x + 1) == optree.tree_map(lambda x: x + 1, tree) + leaves, treespec = optree.tree_flatten(tree, none_is_leaf=True) - f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited = get_functions() + f_node, f_leaf, *_ = get_functions() with pytest.raises(ValueError, match='Too few leaves for PyTreeSpec.'): - treespec.walk(f_node, f_leaf, leaves[:-1]) + treespec.walk(leaves[:-1], f_node, f_leaf) - f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited = get_functions() + f_node, f_leaf, *_ = get_functions() with pytest.raises(ValueError, match='Too many leaves for PyTreeSpec.'): - treespec.walk(f_node, f_leaf, (*leaves, 0)) - - f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited = get_functions() - output = treespec.walk(f_node, f_leaf, leaves) + treespec.walk((*leaves, 0), f_node, f_leaf) + + ( + f_node, + f_leaf, + leaves_visited, + nodes_visited, + node_data_visited, + node_types_visited, + ) = get_functions() + output = treespec.walk(leaves, f_node, f_leaf) assert leaves_visited == [1, 2, 3, None, 4] assert nodes_visited == [ ([1, 2, 3], [1, 2, 3, None], [1, 2, 3, None, 
4]), ([1], [1, 2], ([([1, 2, 3], [1, 2, 3, None], [1, 2, 3, None, 4])], None)), ] assert node_data_visited == [['e', 'f', 'g'], ['a', 'b', 'c']] + assert node_types_visited == [dict, dict] assert output == ( [ ([1, 2, 3], [1, 2, 3, None], [1, 2, 3, None, 4]), @@ -317,6 +342,14 @@ def f_leaf(leaf): None, ) + assert treespec.walk(leaves) == tree + assert treespec.walk(leaves, unflatten_node, None) == tree + assert treespec.walk(leaves, None, lambda x: (x,)) == optree.tree_map( + lambda x: (x,), + tree, + none_is_leaf=True, + ) + def test_flatten_up_to(): treespec = optree.tree_structure([(1, 2), None, CustomTuple(foo=3, bar=7)])