diff --git a/include/optree/treespec.h b/include/optree/treespec.h index 9b2e7ec4..1ab766d7 100644 --- a/include/optree/treespec.h +++ b/include/optree/treespec.h @@ -138,10 +138,10 @@ class PyTreeSpec { [[nodiscard]] std::unique_ptr Compose(const PyTreeSpec &inner_treespec) const; // Map a function over a PyTree structure, applying f_leaf to each leaf, and - // f_node(children, node_data) to each container node. - [[nodiscard]] py::object Walk(const py::function &f_node, - const std::optional &f_leaf, - const py::iterable &leaves) const; + // f_node(node_type, node_data, children) to each container node. + [[nodiscard]] py::object Walk(const py::iterable &leaves, + const std::optional &f_node = std::nullopt, + const std::optional &f_leaf = std::nullopt) const; // Return paths to all leaves in the PyTreeSpec. [[nodiscard]] std::vector Paths() const; diff --git a/optree/_C.pyi b/optree/_C.pyi index 32d1e7c9..1df7f036 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -119,9 +119,9 @@ class PyTreeSpec: def compose(self, inner_treespec: Self) -> Self: ... def walk( self, - f_node: Callable[[tuple[U, ...], MetaData], U], - f_leaf: Callable[[T], U] | None, leaves: Iterable[T], + f_node: Callable[[builtins.type, MetaData, tuple[U, ...]], U] | None = None, + f_leaf: Callable[[T], U] | None = None, ) -> U: ... def paths(self) -> list[tuple[Any, ...]]: ... def accessors(self) -> list[PyTreeAccessor]: ... diff --git a/src/optree.cpp b/src/optree.cpp index 756d6973..a82ff5f3 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -212,11 +212,11 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] py::arg("inner_treespec")) .def("walk", &PyTreeSpec::Walk, - "Walk over the pytree structure, calling ``f_node(children, node_data)`` at nodes, " - "and ``f_leaf(leaf)`` at leaves.", - py::arg("f_node"), - py::arg("f_leaf"), - py::arg("leaves")) + "Walk over the pytree structure, calling ``f_node(node_type, node_data, children)`` " + "at nodes, and ``f_leaf(leaf)`` at leaves.", + py::arg("leaves"), + py::arg("f_node") = std::nullopt, + py::arg("f_leaf") = std::nullopt) .def("paths", &PyTreeSpec::Paths, "Return a list of paths to the leaves of the treespec.") .def("accessors", &PyTreeSpec::Accessors, diff --git a/src/treespec/traversal.cpp b/src/treespec/traversal.cpp index 95dc0873..86c4ae99 100644 --- a/src/treespec/traversal.cpp +++ b/src/treespec/traversal.cpp @@ -18,6 +18,7 @@ limitations under the License. #include // std::optional #include // std::ostringstream #include // std::runtime_error +#include // std::move #include "optree/optree.h" @@ -171,9 +172,10 @@ py::object PyTreeIter::Next() { } } -py::object PyTreeSpec::Walk(const py::function& f_node, - const std::optional& f_leaf, - const py::iterable& leaves) const { +// NOLINTNEXTLINE[readability-function-cognitive-complexity] +py::object PyTreeSpec::Walk(const py::iterable& leaves, + const std::optional& f_node, + const std::optional& f_leaf) const { PYTREESPEC_SANITY_CHECK(*this); const scoped_critical_section cs{leaves}; @@ -203,18 +205,34 @@ py::object PyTreeSpec::Walk(const py::function& f_node, case PyTreeKind::Deque: case PyTreeKind::StructSequence: case PyTreeKind::Custom: { - EXPECT_GE(py::ssize_t_cast(agenda.size()), - node.arity, - "Too few elements for custom type."); - const py::tuple tuple{node.arity}; - for (ssize_t i = node.arity - 1; i >= 0; --i) { - TupleSetItem(tuple, i, agenda.back()); - agenda.pop_back(); + const ssize_t size = py::ssize_t_cast(agenda.size()); + EXPECT_GE(size, node.arity, "Too few elements for custom type."); + + if (f_node) [[likely]] { + const py::tuple children{node.arity}; + for (ssize_t i = node.arity - 1; i >= 0; --i) { + TupleSetItem(children, i, agenda.back()); + agenda.pop_back(); + } + + const py::object& node_type = GetType(node); + { + const scoped_critical_section cs{node_type}; + agenda.emplace_back(EVALUATE_WITH_LOCK_HELD2( + (*f_node)(node_type, + (node.node_data ? node.node_data : py::none()), + children), + node.node_data, + *f_node)); + } + } else [[unlikely]] { + py::object out = + MakeNode(node, + (node.arity > 0 ? &agenda[size - node.arity] : nullptr), + node.arity); + agenda.resize(size - node.arity); + agenda.emplace_back(std::move(out)); } - agenda.emplace_back(EVALUATE_WITH_LOCK_HELD2( - f_node(tuple, (node.node_data ? node.node_data : py::none())), - node.node_data, - f_node)); break; } diff --git a/tests/test_ops.py b/tests/test_ops.py index 48a7f6b1..5622a8e2 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -247,34 +247,46 @@ def test_walk(): # X # + def unflatten_node(node_type, node_data, children): + return optree.register_pytree_node.get(node_type).unflatten_func(node_data, children) + def get_functions(): - nodes_visited = [] + node_types_visited = [] node_data_visited = [] + nodes_visited = [] leaves_visited = [] - def f_node(node, node_data): - nodes_visited.append(node) + def f_node(node_type, node_data, node): + node_types_visited.append(node_type) node_data_visited.append(node_data) + nodes_visited.append(node) return copy.deepcopy(nodes_visited), None def f_leaf(leaf): leaves_visited.append(leaf) return copy.deepcopy(leaves_visited) - return f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited + return f_node, f_leaf, leaves_visited, nodes_visited, node_data_visited, node_types_visited leaves, treespec = optree.tree_flatten(tree) - f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited = get_functions() + f_node, f_leaf, *_ = get_functions() with pytest.raises(ValueError, match='Too few leaves for PyTreeSpec.'): - treespec.walk(f_node, f_leaf, leaves[:-1]) + treespec.walk(leaves[:-1], f_node, f_leaf) - f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited = get_functions() + f_node, f_leaf, *_ = get_functions() with pytest.raises(ValueError, match='Too many leaves for PyTreeSpec.'): - treespec.walk(f_node, f_leaf, (*leaves, 0)) - - f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited = get_functions() - output = treespec.walk(f_node, f_leaf, leaves) + treespec.walk((*leaves, 0), f_node, f_leaf) + + ( + f_node, + f_leaf, + leaves_visited, + nodes_visited, + node_data_visited, + node_types_visited, + ) = get_functions() + output = treespec.walk(leaves, f_node, f_leaf) assert leaves_visited == [1, 2, 3, 4] assert nodes_visited == [ (), @@ -282,6 +294,7 @@ def f_leaf(leaf): ([1], [1, 2], ([(), ([1, 2, 3], ([()], None), [1, 2, 3, 4])], None)), ] assert node_data_visited == [None, ['e', 'f', 'g'], ['a', 'b', 'c']] + assert node_types_visited == [type(None), dict, dict] assert output == ( [ (), @@ -291,24 +304,36 @@ def f_leaf(leaf): None, ) + assert treespec.walk(leaves) == tree + assert treespec.walk(leaves, unflatten_node, None) == tree + assert treespec.walk(leaves, None, lambda x: x + 1) == optree.tree_map(lambda x: x + 1, tree) + leaves, treespec = optree.tree_flatten(tree, none_is_leaf=True) - f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited = get_functions() + f_node, f_leaf, *_ = get_functions() with pytest.raises(ValueError, match='Too few leaves for PyTreeSpec.'): - treespec.walk(f_node, f_leaf, leaves[:-1]) + treespec.walk(leaves[:-1], f_node, f_leaf) - f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited = get_functions() + f_node, f_leaf, *_ = get_functions() with pytest.raises(ValueError, match='Too many leaves for PyTreeSpec.'): - treespec.walk(f_node, f_leaf, (*leaves, 0)) - - f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited = get_functions() - output = treespec.walk(f_node, f_leaf, leaves) + treespec.walk((*leaves, 0), f_node, f_leaf) + + ( + f_node, + f_leaf, + leaves_visited, + nodes_visited, + node_data_visited, + node_types_visited, + ) = get_functions() + output = treespec.walk(leaves, f_node, f_leaf) assert leaves_visited == [1, 2, 3, None, 4] assert nodes_visited == [ ([1, 2, 3], [1, 2, 3, None], [1, 2, 3, None, 4]), ([1], [1, 2], ([([1, 2, 3], [1, 2, 3, None], [1, 2, 3, None, 4])], None)), ] assert node_data_visited == [['e', 'f', 'g'], ['a', 'b', 'c']] + assert node_types_visited == [dict, dict] assert output == ( [ ([1, 2, 3], [1, 2, 3, None], [1, 2, 3, None, 4]), @@ -317,6 +342,14 @@ def f_leaf(leaf): None, ) + assert treespec.walk(leaves) == tree + assert treespec.walk(leaves, unflatten_node, None) == tree + assert treespec.walk(leaves, None, lambda x: (x,)) == optree.tree_map( + lambda x: (x,), + tree, + none_is_leaf=True, + ) + def test_flatten_up_to(): treespec = optree.tree_structure([(1, 2), None, CustomTuple(foo=3, bar=7)])