Skip to content

Commit

Permalink
refactor(treespec): refactor PyTreeSpec.walk
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Dec 10, 2024
1 parent 02e7915 commit 31cbbae
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 43 deletions.
8 changes: 4 additions & 4 deletions include/optree/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ class PyTreeSpec {
[[nodiscard]] std::unique_ptr<PyTreeSpec> Compose(const PyTreeSpec &inner_treespec) const;

// Map a function over a PyTree structure, applying f_leaf to each leaf, and
// f_node(children, node_data) to each container node.
[[nodiscard]] py::object Walk(const py::function &f_node,
const std::optional<py::function> &f_leaf,
const py::iterable &leaves) const;
// f_node(node_type, node_data, children) to each container node.
[[nodiscard]] py::object Walk(const py::iterable &leaves,
const std::optional<py::function> &f_node = std::nullopt,
const std::optional<py::function> &f_leaf = std::nullopt) const;

// Return paths to all leaves in the PyTreeSpec.
[[nodiscard]] std::vector<py::tuple> Paths() const;
Expand Down
4 changes: 2 additions & 2 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ class PyTreeSpec:
def compose(self, inner_treespec: Self) -> Self: ...
def walk(
self,
f_node: Callable[[tuple[U, ...], MetaData], U],
f_leaf: Callable[[T], U] | None,
leaves: Iterable[T],
f_node: Callable[[builtins.type, MetaData, tuple[U, ...]], U] | None = None,
f_leaf: Callable[[T], U] | None = None,
) -> U: ...
def paths(self) -> list[tuple[Any, ...]]: ...
def accessors(self) -> list[PyTreeAccessor]: ...
Expand Down
10 changes: 5 additions & 5 deletions src/optree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,11 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
py::arg("inner_treespec"))
.def("walk",
&PyTreeSpec::Walk,
"Walk over the pytree structure, calling ``f_node(children, node_data)`` at nodes, "
"and ``f_leaf(leaf)`` at leaves.",
py::arg("f_node"),
py::arg("f_leaf"),
py::arg("leaves"))
"Walk over the pytree structure, calling ``f_node(node_type, node_data, children)`` "
"at nodes, and ``f_leaf(leaf)`` at leaves.",
py::arg("leaves"),
py::arg("f_node") = std::nullopt,
py::arg("f_leaf") = std::nullopt)
.def("paths", &PyTreeSpec::Paths, "Return a list of paths to the leaves of the treespec.")
.def("accessors",
&PyTreeSpec::Accessors,
Expand Down
46 changes: 32 additions & 14 deletions src/treespec/traversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <optional> // std::optional
#include <sstream> // std::ostringstream
#include <stdexcept> // std::runtime_error
#include <utility> // std::move

#include "optree/optree.h"

Expand Down Expand Up @@ -171,9 +172,10 @@ py::object PyTreeIter::Next() {
}
}

py::object PyTreeSpec::Walk(const py::function& f_node,
const std::optional<py::function>& f_leaf,
const py::iterable& leaves) const {
// NOLINTNEXTLINE[readability-function-cognitive-complexity]
py::object PyTreeSpec::Walk(const py::iterable& leaves,
const std::optional<py::function>& f_node,
const std::optional<py::function>& f_leaf) const {
PYTREESPEC_SANITY_CHECK(*this);

const scoped_critical_section cs{leaves};
Expand Down Expand Up @@ -203,18 +205,34 @@ py::object PyTreeSpec::Walk(const py::function& f_node,
case PyTreeKind::Deque:
case PyTreeKind::StructSequence:
case PyTreeKind::Custom: {
EXPECT_GE(py::ssize_t_cast(agenda.size()),
node.arity,
"Too few elements for custom type.");
const py::tuple tuple{node.arity};
for (ssize_t i = node.arity - 1; i >= 0; --i) {
TupleSetItem(tuple, i, agenda.back());
agenda.pop_back();
const ssize_t size = py::ssize_t_cast(agenda.size());
EXPECT_GE(size, node.arity, "Too few elements for custom type.");

if (f_node) [[likely]] {
const py::tuple children{node.arity};
for (ssize_t i = node.arity - 1; i >= 0; --i) {
TupleSetItem(children, i, agenda.back());
agenda.pop_back();
}

const py::object& node_type = GetType(node);
{
const scoped_critical_section cs{node_type};
agenda.emplace_back(EVALUATE_WITH_LOCK_HELD2(
(*f_node)(node_type,
(node.node_data ? node.node_data : py::none()),
children),
node.node_data,
*f_node));
}
} else [[unlikely]] {
py::object out =
MakeNode(node,
(node.arity > 0 ? &agenda[size - node.arity] : nullptr),
node.arity);
agenda.resize(size - node.arity);
agenda.emplace_back(std::move(out));
}
agenda.emplace_back(EVALUATE_WITH_LOCK_HELD2(
f_node(tuple, (node.node_data ? node.node_data : py::none())),
node.node_data,
f_node));
break;
}

Expand Down
69 changes: 51 additions & 18 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,41 +247,54 @@ def test_walk():
# X
#

def unflatten_node(node_type, node_data, children):
return optree.register_pytree_node.get(node_type).unflatten_func(node_data, children)

def get_functions():
nodes_visited = []
node_types_visited = []
node_data_visited = []
nodes_visited = []
leaves_visited = []

def f_node(node, node_data):
nodes_visited.append(node)
def f_node(node_type, node_data, node):
node_types_visited.append(node_type)
node_data_visited.append(node_data)
nodes_visited.append(node)
return copy.deepcopy(nodes_visited), None

def f_leaf(leaf):
leaves_visited.append(leaf)
return copy.deepcopy(leaves_visited)

return f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited
return f_node, f_leaf, leaves_visited, nodes_visited, node_data_visited, node_types_visited

leaves, treespec = optree.tree_flatten(tree)

f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited = get_functions()
f_node, f_leaf, *_ = get_functions()
with pytest.raises(ValueError, match='Too few leaves for PyTreeSpec.'):
treespec.walk(f_node, f_leaf, leaves[:-1])
treespec.walk(leaves[:-1], f_node, f_leaf)

f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited = get_functions()
f_node, f_leaf, *_ = get_functions()
with pytest.raises(ValueError, match='Too many leaves for PyTreeSpec.'):
treespec.walk(f_node, f_leaf, (*leaves, 0))

f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited = get_functions()
output = treespec.walk(f_node, f_leaf, leaves)
treespec.walk((*leaves, 0), f_node, f_leaf)

(
f_node,
f_leaf,
leaves_visited,
nodes_visited,
node_data_visited,
node_types_visited,
) = get_functions()
output = treespec.walk(leaves, f_node, f_leaf)
assert leaves_visited == [1, 2, 3, 4]
assert nodes_visited == [
(),
([1, 2, 3], ([()], None), [1, 2, 3, 4]),
([1], [1, 2], ([(), ([1, 2, 3], ([()], None), [1, 2, 3, 4])], None)),
]
assert node_data_visited == [None, ['e', 'f', 'g'], ['a', 'b', 'c']]
assert node_types_visited == [type(None), dict, dict]
assert output == (
[
(),
Expand All @@ -291,24 +304,36 @@ def f_leaf(leaf):
None,
)

assert treespec.walk(leaves) == tree
assert treespec.walk(leaves, unflatten_node, None) == tree
assert treespec.walk(leaves, None, lambda x: x + 1) == optree.tree_map(lambda x: x + 1, tree)

leaves, treespec = optree.tree_flatten(tree, none_is_leaf=True)

f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited = get_functions()
f_node, f_leaf, *_ = get_functions()
with pytest.raises(ValueError, match='Too few leaves for PyTreeSpec.'):
treespec.walk(f_node, f_leaf, leaves[:-1])
treespec.walk(leaves[:-1], f_node, f_leaf)

f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited = get_functions()
f_node, f_leaf, *_ = get_functions()
with pytest.raises(ValueError, match='Too many leaves for PyTreeSpec.'):
treespec.walk(f_node, f_leaf, (*leaves, 0))

f_node, f_leaf, nodes_visited, node_data_visited, leaves_visited = get_functions()
output = treespec.walk(f_node, f_leaf, leaves)
treespec.walk((*leaves, 0), f_node, f_leaf)

(
f_node,
f_leaf,
leaves_visited,
nodes_visited,
node_data_visited,
node_types_visited,
) = get_functions()
output = treespec.walk(leaves, f_node, f_leaf)
assert leaves_visited == [1, 2, 3, None, 4]
assert nodes_visited == [
([1, 2, 3], [1, 2, 3, None], [1, 2, 3, None, 4]),
([1], [1, 2], ([([1, 2, 3], [1, 2, 3, None], [1, 2, 3, None, 4])], None)),
]
assert node_data_visited == [['e', 'f', 'g'], ['a', 'b', 'c']]
assert node_types_visited == [dict, dict]
assert output == (
[
([1, 2, 3], [1, 2, 3, None], [1, 2, 3, None, 4]),
Expand All @@ -317,6 +342,14 @@ def f_leaf(leaf):
None,
)

assert treespec.walk(leaves) == tree
assert treespec.walk(leaves, unflatten_node, None) == tree
assert treespec.walk(leaves, None, lambda x: (x,)) == optree.tree_map(
lambda x: (x,),
tree,
none_is_leaf=True,
)


def test_flatten_up_to():
treespec = optree.tree_structure([(1, 2), None, CustomTuple(foo=3, bar=7)])
Expand Down

0 comments on commit 31cbbae

Please sign in to comment.